Skip to content

Commit ef9ef92

Browse files
committed
edit
1 parent f8180e2 commit ef9ef92

3 files changed

Lines changed: 45 additions & 57 deletions

File tree

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
from gt4py.next import common as gtx_common, utils as gtx_utils
1818
from gt4py.next.iterator import ir as gtir
19-
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils
19+
from gt4py.next.iterator.ir_utils import (
20+
common_pattern_matcher as cpm,
21+
domain_utils,
22+
ir_makers as im,
23+
)
2024
from gt4py.next.program_processors.runners.dace import (
2125
gtir_dataflow,
2226
gtir_domain,
@@ -247,10 +251,16 @@ def translate_as_fieldop(
247251
raise NotImplementedError("Unexpected 'as_filedop' with tuple output in SDFG lowering.")
248252

249253
if cpm.is_ref_to(fieldop_expr, "deref"):
250-
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
251-
# value to 'as_fieldop' function. It results in broadcasting the scalar value
252-
# over the field domain.
253-
return translate_broadcast(node, ctx, sdfg_builder)
254+
if isinstance(node.args[0].type, ts.ScalarType):
255+
# Special usage of 'deref' as argument to fieldop expression, to broadcast
256+
# a scalar value on the field domain.
257+
return translate_broadcast(node, ctx, sdfg_builder)
258+
else:
259+
# Special usage of 'deref' with field argument, to return a subset of
260+
# the full field domain.
261+
# TODO(edopao): Lower this case to a memlet edge, planned for next PR.
262+
stencil_expr = im.lambda_("a")(im.deref("a"))
263+
stencil_expr.expr.type = node.type.dtype
254264
elif isinstance(fieldop_expr, gtir.Lambda):
255265
# Default case, handled below: the argument expression is a lambda function
256266
# representing the stencil operation to be computed over the field domain.
@@ -285,6 +295,7 @@ def translate_broadcast(
285295
ctx: gtir_to_sdfg.SubgraphContext,
286296
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
287297
) -> gtir_to_sdfg_types.FieldopData:
298+
"""Translates a broadcast expression which writes a scalar value on the field domain."""
288299
assert isinstance(node, gtir.FunCall)
289300
assert cpm.is_call_to(node.fun, "as_fieldop")
290301

@@ -294,6 +305,10 @@ def translate_broadcast(
294305
assert isinstance(node.type.dtype, ts.ScalarType)
295306
field_dtype = gtx_dace_utils.as_dace_type(node.type.dtype)
296307

308+
assert len(node.args) == 1
309+
assert isinstance(node.args[0].type, ts.ScalarType)
310+
scalar_arg = node.args[0]
311+
297312
fun_node = node.fun
298313
assert len(fun_node.args) == 2
299314
fieldop_expr, fieldop_domain_expr = fun_node.args
@@ -308,32 +323,23 @@ def translate_broadcast(
308323
# The memory layout of the output field follows the field operator compute domain.
309324
field_dims, field_origin, field_shape = gtir_domain.get_field_layout(field_domain)
310325
assert field_dims == node.type.dims
311-
312326
field_name, field_desc = sdfg_builder.add_temp_array(ctx.sdfg, field_shape, field_dtype)
313327
field_node = ctx.state.add_access(field_name)
314328

315329
# Retrieve the scalar argument, which could be either a literal value or the
316330
# result of a scalar expression.
317-
assert len(node.args) == 1
331+
arg = _parse_fieldop_arg(scalar_arg, ctx, sdfg_builder, field_domain)
332+
assert isinstance(arg, gtir_dataflow.MemletExpr)
333+
assert arg.subset.num_elements() == 1
318334

319335
# Use a 'Fill' library node to write the scalar value to the result field.
320-
if isinstance(node.args[0], gtir.Literal):
321-
assert node.args[0].type == node.type.dtype
322-
value = field_dtype(node.args[0].value)
323-
fill_node = sdfg_library_nodes.Fill("fill", value)
324-
ctx.state.add_node(fill_node)
325-
else:
326-
arg = _parse_fieldop_arg(node.args[0], ctx, sdfg_builder, field_domain)
327-
assert isinstance(arg, gtir_dataflow.MemletExpr)
328-
assert arg.subset.num_elements() == 1
329-
330-
fill_node = sdfg_library_nodes.Fill("fill")
331-
ctx.state.add_node(fill_node)
332-
ctx.state.add_nedge(
333-
arg.dc_node,
334-
fill_node,
335-
dace.Memlet(data=arg.dc_node.data, subset=arg.subset),
336-
)
336+
fill_node = sdfg_library_nodes.Fill("fill")
337+
ctx.state.add_node(fill_node)
338+
ctx.state.add_nedge(
339+
arg.dc_node,
340+
fill_node,
341+
dace.Memlet(data=arg.dc_node.data, subset=arg.subset),
342+
)
337343

338344
ctx.state.add_nedge(
339345
fill_node,

src/gt4py/next/program_processors/runners/dace/sdfg_library_nodes.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,16 @@ def expansion(node: Fill, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG)
4444
out_mem = dace.Memlet(expr=f"{out}[{','.join(map_params)}]")
4545
outputs = {"_out": out_mem}
4646

47-
if node._value is None:
48-
assert len(parent_state.in_edges(node)) == 1
49-
inedge = parent_state.in_edges(node)[0]
50-
inp_desc = parent_sdfg.arrays[inedge.data.data]
51-
inner_inp_desc = inp_desc.clone()
52-
inner_inp_desc.transient = False
53-
inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc)
54-
inedge._dst_conn = _INPUT_NAME
55-
node.add_in_connector(_INPUT_NAME)
56-
inputs = {"_in": dace.Memlet(data=inp, subset="0")}
57-
code = "_out = _in"
58-
else:
59-
inputs = {}
60-
code = f"_out = {node._value}"
47+
assert len(parent_state.in_edges(node)) == 1
48+
inedge = parent_state.in_edges(node)[0]
49+
inp_desc = parent_sdfg.arrays[inedge.data.data]
50+
inner_inp_desc = inp_desc.clone()
51+
inner_inp_desc.transient = False
52+
inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc)
53+
inedge._dst_conn = _INPUT_NAME
54+
node.add_in_connector(_INPUT_NAME)
55+
inputs = {"_in": dace.Memlet(data=inp, subset="0")}
56+
code = "_out = _in"
6157

6258
state.add_mapped_tasklet(
6359
f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True
@@ -72,8 +68,6 @@ class Fill(dace_nodes.LibraryNode):
7268

7369
implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {"pure": ExpandPure}
7470
default_implementation: Final[str] = "pure"
75-
_value: dace.typeclass | None
7671

77-
def __init__(self, name: str, value: dace.typeclass | None = None):
72+
def __init__(self, name: str):
7873
super().__init__(name)
79-
self._value = value

tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,16 +2110,8 @@ def test_gtir_concat_where():
21102110
gtx_common.GridType.CARTESIAN, {IDim: (SUBSET_SIZE, gtir.InfinityLiteral.POSITIVE)}
21112111
)
21122112

2113-
concat_expr_lhs = im.concat_where(
2114-
domain_cond_lhs,
2115-
im.as_fieldop("deref")("x"),
2116-
im.as_fieldop("deref")("y"),
2117-
)
2118-
concat_expr_rhs = im.concat_where(
2119-
domain_cond_rhs,
2120-
im.as_fieldop("deref")("y"),
2121-
im.as_fieldop("deref")("x"),
2122-
)
2113+
concat_expr_lhs = im.concat_where(domain_cond_lhs, "x", "y")
2114+
concat_expr_rhs = im.concat_where(domain_cond_rhs, "y", "x")
21232115

21242116
a = np.random.rand(N)
21252117
b = np.random.rand(N)
@@ -2177,12 +2169,8 @@ def test_gtir_concat_where_two_dimensions():
21772169
gtir.SetAt(
21782170
expr=im.concat_where(
21792171
domain_cond1, # 0, 30; 10,20
2180-
im.concat_where(
2181-
domain_cond2,
2182-
im.as_fieldop("deref")("x"),
2183-
im.as_fieldop("deref")("y"),
2184-
),
2185-
im.as_fieldop("deref")("w"),
2172+
im.concat_where(domain_cond2, "x", "y"),
2173+
"w",
21862174
),
21872175
domain=domain,
21882176
target=gtir.SymRef(id="z"),

0 commit comments

Comments
 (0)