Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram:
for name, descr in static_arg_descriptors.items()
if not any(el is None for el in gtx_utils.flatten_nested_tuple(descr)) # type: ignore[arg-type]
}
body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args)
body = [
remap_symbols.RemapSymbolRefs.apply(stmt, symbol_map=static_args)
for stmt in itir_program.body
] # type: ignore[arg-type]
itir_program = itir.Program(
id=itir_program.id,
function_definitions=itir_program.function_definitions,
Expand Down
61 changes: 19 additions & 42 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Mapping, Optional, TypeVar
from typing import Optional, TypeVar, Mapping

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs
from gt4py.next.iterator.type_system import inference as itir_inference


Expand All @@ -34,7 +34,9 @@ def inline_lambda( # see todo above
assert len(eligible_params) == len(node.fun.params) == len(node.args)

if opcount_preserving:
ref_counts = CountSymbolRefs.apply(node.fun.expr, [p.id for p in node.fun.params])
ref_counts = symbol_ref_utils.CountSymbolRefs.apply(
node.fun.expr, [p.id for p in node.fun.params]
)

for i, param in enumerate(node.fun.params):
# TODO(tehrengruber): allow inlining more complicated zero-op expressions like ignore_shift(...)(it_sym)
Expand Down Expand Up @@ -64,49 +66,24 @@ def inline_lambda( # see todo above
if node.fun.params and not any(eligible_params):
return node

refs: set[str] = set().union(
*(
arg.pre_walk_values().if_isinstance(ir.SymRef).getattr("id").to_set()
for arg, eligible in zip(node.args, eligible_params)
if eligible
)
)
syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
clashes = refs & syms
fun = node.fun
if clashes:
# TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work:
# (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code]
name_map: dict[str, str] = {}

for sym in clashes:
name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()})

# Let's rename the symbols (including params) of the function.
# If we would like to preserve the original param names, we could alternatively
# rename the eligible symrefs in `args`.
fun = RenameSymbols().visit(fun, name_map=name_map)

symbol_map = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just leaving a remark here, that there was #2134. Maybe it's already resolved here.

param.id: arg
for param, arg, eligible in zip(fun.params, node.args, eligible_params)
str(param.id): arg
for param, arg, eligible in zip(node.fun.params, node.args, eligible_params)
if eligible
}
new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map)

new_fun_proto = im.lambda_(
*(param for param, eligible in zip(node.fun.params, eligible_params) if not eligible)
)(node.fun.expr)
new_fun_proto = RemapSymbolRefs.apply(new_fun_proto, symbol_map=symbol_map)
new_expr = im.call(new_fun_proto)(
*(arg for arg, eligible in zip(node.args, eligible_params) if not eligible)
)

if all(eligible_params):
new_expr = new_expr.fun.expr
new_expr.location = node.location
else:
new_expr = ir.FunCall(
fun=ir.Lambda(
params=[
param for param, eligible in zip(fun.params, eligible_params) if not eligible
],
expr=new_expr,
),
args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible],
location=node.location,
)

for attr in ("type", "recorded_shifts", "domain"):
if hasattr(node.annex, attr):
setattr(new_expr.annex, attr, getattr(node.annex, attr))
Expand Down
47 changes: 43 additions & 4 deletions src/gt4py/next/iterator/transforms/remap_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,59 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.iterator.type_system import inference as type_inference


class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator):
# This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex.
PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain")

def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]):
@classmethod
def apply(cls, node: ir.Node, symbol_map: Dict[str, ir.Node]):
external_symbols = set().union(
*(symbol_ref_utils.collect_symbol_refs(expr) for expr in [node, *symbol_map.values()])
)
return cls().visit(node, symbol_map=symbol_map, reserved_params=external_symbols)

def visit_SymRef(
self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str]
):
return symbol_map.get(str(node.id), node)

def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]):
def visit_Lambda(
self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str]
):
params = {str(p.id) for p in node.params}
new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params}
return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map))

clashes = params & reserved_params
if clashes:
reserved_params = {*reserved_params}
new_symbol_map: Dict[str, ir.Node] = {}
new_params: list[ir.Sym] = []
for param in node.params:
if param.id in clashes:
new_param = im.sym(ir_misc.unique_symbol(param.id, reserved_params), param.type)
assert new_param.id not in symbol_map
new_symbol_map[param.id] = im.ref(new_param.id, param.type)
reserved_params.add(new_param.id)
else:
new_param = param
new_params.append(new_param)

new_symbol_map = symbol_map | new_symbol_map
else:
new_params = node.params # keep params as is
new_symbol_map = symbol_map

filtered_symbol_map = {k: v for k, v in new_symbol_map.items() if k not in new_params}
return ir.Lambda(
params=new_params,
expr=self.visit(
node.expr, symbol_map=filtered_symbol_map, reserved_params=reserved_params
),
)

def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override]
assert isinstance(node, SymbolTableTrait) == isinstance(node, ir.Lambda), (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@
),
im.multiplies_(im.plus(2, 1), im.plus("x", "x")),
),
(
"name_shadowing_external",
im.call(im.lambda_("x")(im.lambda_("y")(im.plus("x", "y"))))(im.plus("x", "y")),
im.lambda_("y_")(im.plus(im.plus("x", "y"), "y_")),
),
(
"renaming_collision",
# the `y` param of the lambda may not be renamed to `y_` as this name is already referenced
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Is this the same problem as in #2134

im.call(im.lambda_("x")(im.lambda_("y")(im.plus(im.plus("x", "y"), "y_"))))(
im.plus("x", "y")
),
im.lambda_("y__")(im.plus(im.plus(im.plus("x", "y"), "y__"), "y_")),
),
(
# ensure opcount preserving option works whether `itir.SymRef` has a type or not
"typed_ref",
Expand Down
Loading