From 711868c8eee0f1df97d3a221c44e42b2fbe5c86d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 19 May 2025 16:05:34 +0200 Subject: [PATCH 1/3] Fast inline lambdas --- src/gt4py/next/ffront/past_to_itir.py | 2 +- .../iterator/transforms/inline_lambdas.py | 60 ++++++------------- .../next/iterator/transforms/remap_symbols.py | 53 ++++++++++++++-- .../transforms_tests/test_inline_lambdas.py | 13 ++++ 4 files changed, 80 insertions(+), 48 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 8b026de487..11b8b08f7d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -99,7 +99,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: itir_program.params[i].id: im.literal_from_tuple_value(value) for i, value in static_args_index.items() } - body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) + body = remap_symbols.RemapSymbolRefs.apply(itir_program.body, symbol_map=static_args) itir_program = itir.Program( id=itir_program.id, function_definitions=itir_program.function_definitions, diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 9053214b39..2d61c65f4e 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -11,9 +11,10 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift -from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols -from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs from gt4py.next.iterator.type_system import inference as itir_inference @@ -33,7 +34,9 @@ def inline_lambda( # see todo above assert len(eligible_params) == len(node.fun.params) == len(node.args) if opcount_preserving: - ref_counts = CountSymbolRefs.apply(node.fun.expr, [p.id for p in node.fun.params]) + ref_counts = symbol_ref_utils.CountSymbolRefs.apply( + node.fun.expr, [p.id for p in node.fun.params] + ) for i, param in enumerate(node.fun.params): # TODO(tehrengruber): allow inlining more complicated zero-op expressions like ignore_shift(...)(it_sym) @@ -63,53 +66,24 @@ def inline_lambda( # see todo above if node.fun.params and not any(eligible_params): return node - refs = set().union( - *( - arg.pre_walk_values().if_isinstance(ir.SymRef).getattr("id").to_set() - for arg, eligible in zip(node.args, eligible_params) - if eligible - ) - ) - syms = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() - clashes = refs & syms - expr = node.fun.expr - if clashes: - # TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work: - # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code] - name_map: dict[ir.SymRef, str] = {} - - def new_name(name): - while name in refs or name in syms or name in name_map.values(): - name += "_" - return name - - for sym in clashes: - name_map[sym] = new_name(sym) - - expr = RenameSymbols().visit(expr, name_map=name_map) - symbol_map = { - param.id: arg + str(param.id): arg for param, arg, eligible in zip(node.fun.params, node.args, eligible_params) if eligible } - new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + + new_fun_proto = im.lambda_( + *(param for param, eligible in zip(node.fun.params, eligible_params) if not eligible) + )(node.fun.expr) + new_fun_proto = RemapSymbolRefs.apply(new_fun_proto, symbol_map=symbol_map) + new_expr = im.call(new_fun_proto)( + *(arg for arg, eligible in zip(node.args, eligible_params) if not eligible) + ) if all(eligible_params): + new_expr = new_expr.fun.expr new_expr.location = node.location - else: - new_expr = ir.FunCall( - fun=ir.Lambda( - params=[ - param - for param, eligible in zip(node.fun.params, eligible_params) - if not eligible - ], - expr=new_expr, - ), - args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], - location=node.location, - ) + for attr in ("type", "recorded_shifts", "domain"): if hasattr(node.annex, attr): setattr(new_expr.annex, attr, getattr(node.annex, attr)) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index fb909dc5d0..053fc95d2a 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,20 +10,65 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.type_system import inference as type_inference +def unique_name(name, prohibited_symbols): + while name in prohibited_symbols: + name += "_" + return name + + class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") - def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): + @classmethod + def apply(cls, node: ir.Node, symbol_map: Dict[str, ir.Node]): + external_symbols = set().union( + *(symbol_ref_utils.collect_symbol_refs(expr) for expr in [node, *symbol_map.values()]) + ) + return cls().visit(node, symbol_map=symbol_map, reserved_params=external_symbols) + + def visit_SymRef( + self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str] + ): return symbol_map.get(str(node.id), node) - def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): + def visit_Lambda( + self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str] + ): params = {str(p.id) for p in node.params} - new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} - return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map)) + + clashes = params & reserved_params + if clashes: + reserved_params = {*reserved_params} + new_symbol_map: Dict[str, ir.Node] = {} + new_params: list[ir.Sym] = [] + for param in node.params: + if param.id in clashes: + new_param = im.sym(unique_name(param.id, reserved_params), param.type) + assert new_param.id not in symbol_map + new_symbol_map[param.id] = im.ref(new_param.id, param.type) + reserved_params.add(new_param.id) + else: + new_param = param + new_params.append(new_param) + + new_symbol_map = symbol_map | new_symbol_map + else: + new_params = node.params # keep params as is + new_symbol_map = symbol_map + + filtered_symbol_map = {k: v for k, v in new_symbol_map.items() if k not in new_params} + return ir.Lambda( + params=new_params, + expr=self.visit( + node.expr, symbol_map=filtered_symbol_map, reserved_params=reserved_params + ), + ) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] assert isinstance(node, SymbolTableTrait) == isinstance( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index c10d48ad06..29eb5b1d82 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -40,6 +40,19 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + "name_shadowing_external", + im.call(im.lambda_("x")(im.lambda_("y")(im.plus("x", "y"))))(im.plus("x", "y")), + im.lambda_("y_")(im.plus(im.plus("x", "y"), "y_")), + ), + ( + "renaming_collision", + # the `y` param of the lambda may not be renamed to `y_` as this name is already referenced + im.call(im.lambda_("x")(im.lambda_("y")(im.plus(im.plus("x", "y"), "y_"))))( + im.plus("x", "y") + ), + im.lambda_("y__")(im.plus(im.plus(im.plus("x", "y"), "y__"), "y_")), + ), ( # ensure opcount preserving option works whether `itir.SymRef` has a type or not "typed_ref", From 4d90dec468cccded5ba2c03c2f625c2ca509c256 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 19 May 2025 16:07:25 +0200 Subject: [PATCH 2/3] Small fixes --- src/gt4py/next/ffront/past_to_itir.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 11b8b08f7d..31fae8a0a9 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -96,10 +96,13 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: i: arg.value for i, arg in enumerate(inp.args.args) if isinstance(arg, arguments.StaticArg) } static_args = { - itir_program.params[i].id: im.literal_from_tuple_value(value) + str(itir_program.params[i].id): im.literal_from_tuple_value(value) for i, value in static_args_index.items() } - body = remap_symbols.RemapSymbolRefs.apply(itir_program.body, symbol_map=static_args) + body = [ + remap_symbols.RemapSymbolRefs.apply(stmt, symbol_map=static_args) + for stmt in itir_program.body + ] itir_program = itir.Program( id=itir_program.id, function_definitions=itir_program.function_definitions, From 3f411f7e4714992f153bd60bf92000dc772f0260 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 19 May 2025 16:09:45 +0200 Subject: [PATCH 3/3] Small fixes --- src/gt4py/next/ffront/past_to_itir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 31fae8a0a9..e6624ce77e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -102,7 +102,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: body = [ remap_symbols.RemapSymbolRefs.apply(stmt, symbol_map=static_args) for stmt in itir_program.body - ] + ] # type: ignore[arg-type] itir_program = itir.Program( id=itir_program.id, function_definitions=itir_program.function_definitions,