Skip to content

Commit c6f1659

Browse files
authored
refactor[next][dace]: Introduce dataclass for field operator domain range (#2142)
This PR introduces a dataclass type to represent the domain range in the lowering to SDFG. It follows a suggestion from @philip-paul-mueller in the review of #2137.
1 parent 0cad969 commit c6f1659

6 files changed

Lines changed: 116 additions & 72 deletions

File tree

src/gt4py/next/program_processors/runners/dace/gtir_domain.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import dataclasses
1112
from typing import Optional, Sequence, TypeAlias
1213

1314
import dace
@@ -19,15 +20,24 @@
1920
from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils
2021

2122

22-
FieldopDomain: TypeAlias = list[
23-
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
24-
]
25-
"""
26-
Domain of a field operator represented as a list of tuples with 3 elements:
27-
- dimension definition
28-
- symbolic expression for lower bound (inclusive)
29-
- symbolic expression for upper bound (exclusive)
30-
"""
23+
@dataclasses.dataclass(frozen=True)
24+
class FieldopDomainRange:
25+
"""
26+
Represents the range of a field operator domain in one dimension.
27+
28+
It contains 3 elements:
29+
dim: dimension definition
30+
start: symbolic expression for lower bound (inclusive)
31+
stop: symbolic expression for upper bound (exclusive)
32+
"""
33+
34+
dim: gtx_common.Dimension
35+
start: dace.symbolic.SymbolicType
36+
stop: dace.symbolic.SymbolicType
37+
38+
39+
FieldopDomain: TypeAlias = list[FieldopDomainRange]
40+
"""Domain of a field operator represented as a list of `FieldopDomainRange` for each dimension."""
3141

3242

3343
def extract_domain(node: gtir.Expr) -> FieldopDomain:
@@ -49,12 +59,12 @@ def extract_domain(node: gtir.Expr) -> FieldopDomain:
4959
gtir_to_sdfg_utils.get_symbolic(arg) for arg in named_range.args[1:3]
5060
)
5161
dim = gtx_common.Dimension(axis.value, axis.kind)
52-
domain.append((dim, lower_bound, upper_bound))
62+
domain.append(FieldopDomainRange(dim, lower_bound, upper_bound))
5363

5464
elif isinstance(node, domain_utils.SymbolicDomain):
5565
for dim, drange in node.ranges.items():
5666
domain.append(
57-
(
67+
FieldopDomainRange(
5868
dim,
5969
gtir_to_sdfg_utils.get_symbolic(drange.start),
6070
gtir_to_sdfg_utils.get_symbolic(drange.stop),
@@ -119,6 +129,7 @@ def get_field_layout(
119129
"""
120130
if len(domain) == 0:
121131
return [], [], []
122-
domain_dims, domain_lbs, domain_ubs = zip(*domain)
123-
domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)]
124-
return list(domain_dims), list(domain_lbs), domain_sizes
132+
domain_dims = [domain_range.dim for domain_range in domain]
133+
domain_origin = [domain_range.start for domain_range in domain]
134+
domain_shape = [(domain_range.stop - domain_range.start) for domain_range in domain]
135+
return domain_dims, domain_origin, domain_shape

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ def _make_access_index_for_field(
197197
# since the access indices have to follow the order of dimensions in field domain
198198
if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0:
199199
assert data.origin is not None
200-
domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain}
200+
domain_ranges = {
201+
domain_range.dim: (domain_range.start, domain_range.stop) for domain_range in domain
202+
}
201203
return dace.subsets.Range(
202204
(domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1)
203205
for dim, origin in zip(data.gt_type.dims, data.origin, strict=True)

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,20 @@ def _translate_concat_where_impl(
185185
gtir_domain.extract_domain(domain) for domain in [tb_node_domain, fb_node_domain]
186186
)
187187
assert len(mask_domain) == 1
188-
concat_dim, mask_lower_bound, mask_upper_bound = mask_domain[0]
188+
concat_domain = mask_domain[0]
189189

190190
# Expect unbound range in the concat domain expression on lower or upper range:
191191
# - if the domain expression is unbound on lower side (negative infinite),
192192
# the expression on the true branch is to be considered the input for the
193193
# lower domain.
194194
# - viceversa, if the domain expression is unbound on upper side (positive
195195
# infinite), the true expression represents the input for the upper domain.
196-
if mask_lower_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE):
197-
concat_dim_bound = mask_upper_bound
196+
if concat_domain.start == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE):
197+
concat_dim_bound = concat_domain.stop
198198
lower, lower_desc, lower_domain = (tb_field, tb_data_desc, tb_domain)
199199
upper, upper_desc, upper_domain = (fb_field, fb_data_desc, fb_domain)
200-
elif mask_upper_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE):
201-
concat_dim_bound = mask_lower_bound
200+
elif concat_domain.stop == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE):
201+
concat_dim_bound = concat_domain.start
202202
lower, lower_desc, lower_domain = (fb_field, fb_data_desc, fb_domain)
203203
upper, upper_desc, upper_domain = (tb_field, tb_data_desc, tb_domain)
204204
else:
@@ -207,9 +207,9 @@ def _translate_concat_where_impl(
207207
# we use the concat domain, stored in the annex, as the domain of output field
208208
output_domain = gtir_domain.extract_domain(node_domain)
209209
output_dims, output_origin, output_shape = _get_concat_where_field_layout(
210-
output_domain, concat_dim
210+
output_domain, concat_domain.dim
211211
)
212-
concat_dim_index = output_dims.index(concat_dim)
212+
concat_dim_index = output_dims.index(concat_domain.dim)
213213

214214
"""
215215
In case one of the arguments is a scalar value, for example:
@@ -225,23 +225,27 @@ def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField:
225225
assert isinstance(upper.gt_type, ts.FieldType)
226226
lower = gtir_to_sdfg_types.FieldopData(
227227
lower.dc_node,
228-
ts.FieldType(dims=[concat_dim], dtype=lower.gt_type),
228+
ts.FieldType(dims=[concat_domain.dim], dtype=lower.gt_type),
229229
origin=(concat_dim_bound - 1,),
230230
)
231-
lower_bound = output_domain[concat_dim_index][1]
232-
lower_domain = [(concat_dim, lower_bound, concat_dim_bound)]
231+
lower_bound = output_domain[concat_dim_index].start
232+
lower_domain = [
233+
gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound)
234+
]
233235
elif isinstance(upper.gt_type, ts.ScalarType):
234236
assert len(upper_domain) == 0
235237
assert isinstance(lower.gt_type, ts.FieldType)
236238
upper = gtir_to_sdfg_types.FieldopData(
237239
upper.dc_node,
238-
ts.FieldType(dims=[concat_dim], dtype=upper.gt_type),
240+
ts.FieldType(dims=[concat_domain.dim], dtype=upper.gt_type),
239241
origin=(concat_dim_bound,),
240242
)
241-
upper_bound = output_domain[concat_dim_index][2]
242-
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]
243+
upper_bound = output_domain[concat_dim_index].stop
244+
upper_domain = [
245+
gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound)
246+
]
243247

244-
if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
248+
if concat_domain.dim not in lower.gt_type.dims: # type: ignore[union-attr]
245249
"""
246250
The field on the lower domain is to be treated as a slice to add as one
247251
level in the concat dimension, on the lower bound.
@@ -261,13 +265,22 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
261265
]
262266
)
263267
lower, lower_desc = _make_concat_field_slice(
264-
sdfg, state, lower, lower_desc, concat_dim, concat_dim_index, concat_dim_bound - 1
268+
sdfg=sdfg,
269+
state=state,
270+
field=lower,
271+
field_desc=lower_desc,
272+
concat_dim=concat_domain.dim,
273+
concat_dim_index=concat_dim_index,
274+
concat_dim_origin=concat_dim_bound - 1,
265275
)
266276
lower_bound = dace.symbolic.pystr_to_symbolic(
267-
f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})"
277+
f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index].start})"
278+
)
279+
lower_domain.insert(
280+
concat_dim_index,
281+
gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound),
268282
)
269-
lower_domain.insert(concat_dim_index, (concat_dim, lower_bound, concat_dim_bound))
270-
elif concat_dim not in upper.gt_type.dims: # type: ignore[union-attr]
283+
elif concat_domain.dim not in upper.gt_type.dims: # type: ignore[union-attr]
271284
# Same as previous case, but the field slice is added on the upper bound.
272285
assert (
273286
upper.gt_type.dims # type: ignore[union-attr]
@@ -277,12 +290,21 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
277290
]
278291
)
279292
upper, upper_desc = _make_concat_field_slice(
280-
sdfg, state, upper, upper_desc, concat_dim, concat_dim_index, concat_dim_bound
293+
sdfg=sdfg,
294+
state=state,
295+
field=upper,
296+
field_desc=upper_desc,
297+
concat_dim=concat_domain.dim,
298+
concat_dim_index=concat_dim_index,
299+
concat_dim_origin=concat_dim_bound,
281300
)
282301
upper_bound = dace.symbolic.pystr_to_symbolic(
283-
f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})"
302+
f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index].stop})"
303+
)
304+
upper_domain.insert(
305+
concat_dim_index,
306+
gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound),
284307
)
285-
upper_domain.insert(concat_dim_index, (concat_dim, concat_dim_bound, upper_bound))
286308
elif isinstance(lower_desc, dace.data.Scalar) or (
287309
len(lower.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr]
288310
):
@@ -297,27 +319,37 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
297319
return concat_where(KDim == 0, a, b)
298320
```
299321
"""
300-
assert len(lower_domain) == 1 and lower_domain[0][0] == concat_dim
322+
assert len(lower_domain) == 1 and lower_domain[0].dim == concat_domain.dim
301323
lower_domain = [
302324
*output_domain[:concat_dim_index],
303325
lower_domain[0],
304326
*output_domain[concat_dim_index + 1 :],
305327
]
306328
lower, lower_desc = _make_concat_scalar_broadcast(
307-
sdfg, state, lower, lower_desc, lower_domain, concat_dim_index
329+
sdfg=sdfg,
330+
state=state,
331+
inp=lower,
332+
inp_desc=lower_desc,
333+
domain=lower_domain,
334+
concat_dim_index=concat_dim_index,
308335
)
309336
elif isinstance(upper_desc, dace.data.Scalar) or (
310337
len(upper.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr]
311338
):
312339
# Same as previous case, but the scalar value is taken from `upper` input.
313-
assert len(upper_domain) == 1 and upper_domain[0][0] == concat_dim
340+
assert len(upper_domain) == 1 and upper_domain[0].dim == concat_domain.dim
314341
upper_domain = [
315342
*output_domain[:concat_dim_index],
316343
upper_domain[0],
317344
*output_domain[concat_dim_index + 1 :],
318345
]
319346
upper, upper_desc = _make_concat_scalar_broadcast(
320-
sdfg, state, upper, upper_desc, upper_domain, concat_dim_index
347+
sdfg=sdfg,
348+
state=state,
349+
inp=upper,
350+
inp_desc=upper_desc,
351+
domain=upper_domain,
352+
concat_dim_index=concat_dim_index,
321353
)
322354
else:
323355
"""
@@ -341,15 +373,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
341373
# ensure that the arguments have the same domain as the concat result
342374
assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) # type: ignore[union-attr]
343375

344-
lower_range_0 = output_domain[concat_dim_index][1]
376+
lower_range_0 = output_domain[concat_dim_index].start
345377
lower_range_1 = dace.symbolic.pystr_to_symbolic(
346-
f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})"
378+
f"max({lower_range_0}, {lower_domain[concat_dim_index].stop})"
347379
)
348380
lower_range_size = lower_range_1 - lower_range_0
349381

350-
upper_range_1 = output_domain[concat_dim_index][2]
382+
upper_range_1 = output_domain[concat_dim_index].stop
351383
upper_range_0 = dace.symbolic.pystr_to_symbolic(
352-
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
384+
f"min({upper_range_1}, {upper_domain[concat_dim_index].start})"
353385
)
354386
upper_range_size = upper_range_1 - upper_range_0
355387

@@ -391,15 +423,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
391423
else:
392424
lower_subset.append(
393425
(
394-
output_domain[dim_index][1] - lower.origin[dim_index],
395-
output_domain[dim_index][1] - lower.origin[dim_index] + size - 1,
426+
output_domain[dim_index].start - lower.origin[dim_index],
427+
output_domain[dim_index].start - lower.origin[dim_index] + size - 1,
396428
1,
397429
)
398430
)
399431
upper_subset.append(
400432
(
401-
output_domain[dim_index][1] - upper.origin[dim_index],
402-
output_domain[dim_index][1] - upper.origin[dim_index] + size - 1,
433+
output_domain[dim_index].start - upper.origin[dim_index],
434+
output_domain[dim_index].start - upper.origin[dim_index] + size - 1,
403435
1,
404436
)
405437
)

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,10 @@ def _create_field_operator(
201201
else:
202202
# create map range corresponding to the field operator domain
203203
map_range = {
204-
gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
205-
for dim, lower_bound, upper_bound in domain
204+
gtir_to_sdfg_utils.get_map_variable(
205+
domain_range.dim
206+
): f"{domain_range.start}:{domain_range.stop}"
207+
for domain_range in domain
206208
}
207209
map_entry, map_exit = sdfg_builder.add_map("fieldop", state, map_range)
208210

@@ -511,8 +513,7 @@ def translate_index(
511513
assert "domain" in node.annex
512514
domain = gtir_domain.extract_domain(node.annex.domain)
513515
assert len(domain) == 1
514-
dim, _, _ = domain[0]
515-
dim_index = gtir_to_sdfg_utils.get_map_variable(dim)
516+
dim_index = gtir_to_sdfg_utils.get_map_variable(domain[0].dim)
516517

517518
index_data, _ = sdfg_builder.add_temp_scalar(sdfg, gtir_to_sdfg_types.INDEX_DTYPE)
518519
index_node = state.add_access(index_data)

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,11 @@ def _create_scan_field_operator(
210210
"fieldop",
211211
state,
212212
ndrange={
213-
gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
214-
for dim, lower_bound, upper_bound in domain
215-
if not sdfg_builder.is_column_axis(dim)
213+
gtir_to_sdfg_utils.get_map_variable(
214+
domain_range.dim
215+
): f"{domain_range.start}:{domain_range.stop}"
216+
for domain_range in domain
217+
if not sdfg_builder.is_column_axis(domain_range.dim)
216218
},
217219
)
218220

@@ -329,22 +331,18 @@ def _lower_lambda_to_nested_sdfg(
329331
)
330332

331333
# use the vertical dimension in the domain as scan dimension
332-
scan_domain = [
333-
(dim, lower_bound, upper_bound)
334-
for dim, lower_bound, upper_bound in domain
335-
if sdfg_builder.is_column_axis(dim)
336-
]
337-
assert len(scan_domain) == 1
338-
scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0]
334+
scan_domain = next(
335+
domain_range for domain_range in domain if sdfg_builder.is_column_axis(domain_range.dim)
336+
)
339337

340338
# extract the scan loop range
341-
scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_dim)
339+
scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_domain.dim)
342340

343341
# in case the scan operator computes a list (not a scalar), we need to add an extra dimension
344342
def get_scan_output_shape(
345343
scan_init_data: gtir_to_sdfg_types.FieldopData,
346344
) -> list[dace.symbolic.SymExpr]:
347-
scan_column_size = scan_upper_bound - scan_lower_bound
345+
scan_column_size = scan_domain.stop - scan_domain.start
348346
if isinstance(scan_init_data.gt_type, ts.ScalarType):
349347
return [scan_column_size]
350348
assert isinstance(scan_init_data.gt_type, ts.ListType)
@@ -391,18 +389,18 @@ def init_scan_carry(sym: gtir.Sym) -> None:
391389
if scan_forward:
392390
scan_loop = dace.sdfg.state.LoopRegion(
393391
label="scan",
394-
condition_expr=f"{scan_loop_var} < {scan_upper_bound}",
392+
condition_expr=f"{scan_loop_var} < {scan_domain.stop}",
395393
loop_var=scan_loop_var,
396-
initialize_expr=f"{scan_loop_var} = {scan_lower_bound}",
394+
initialize_expr=f"{scan_loop_var} = {scan_domain.start}",
397395
update_expr=f"{scan_loop_var} = {scan_loop_var} + 1",
398396
inverted=False,
399397
)
400398
else:
401399
scan_loop = dace.sdfg.state.LoopRegion(
402400
label="scan",
403-
condition_expr=f"{scan_loop_var} >= {scan_lower_bound}",
401+
condition_expr=f"{scan_loop_var} >= {scan_domain.start}",
404402
loop_var=scan_loop_var,
405-
initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1",
403+
initialize_expr=f"{scan_loop_var} = {scan_domain.stop} - 1",
406404
update_expr=f"{scan_loop_var} = {scan_loop_var} - 1",
407405
inverted=False,
408406
)
@@ -431,7 +429,7 @@ def init_scan_carry(sym: gtir.Sym) -> None:
431429
for edge in lambda_input_edges:
432430
edge.connect(map_entry=None)
433431
# connect the dataflow output nodes, called 'scan_result' below, to a global field called 'output'
434-
output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound
432+
output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_domain.start
435433

436434
def connect_scan_output(
437435
scan_output_edge: gtir_dataflow.DataflowOutputEdge,
@@ -475,8 +473,8 @@ def connect_scan_output(
475473
dace.Memlet.from_array(scan_result_data, scan_result_desc),
476474
)
477475

478-
output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype)
479-
return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_lower_bound,))
476+
output_type = ts.FieldType(dims=[scan_domain.dim], dtype=scan_result.gt_dtype)
477+
return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_domain.start,))
480478

481479
# write the stencil result (value on one vertical level) into a 1D field
482480
# with full vertical shape representing one column

0 commit comments

Comments
 (0)