@@ -185,20 +185,20 @@ def _translate_concat_where_impl(
185185 gtir_domain .extract_domain (domain ) for domain in [tb_node_domain , fb_node_domain ]
186186 )
187187 assert len (mask_domain ) == 1
188- concat_dim , mask_lower_bound , mask_upper_bound = mask_domain [0 ]
188+ concat_domain = mask_domain [0 ]
189189
190190 # Expect unbound range in the concat domain expression on lower or upper range:
191191 # - if the domain expression is unbound on lower side (negative infinite),
192192 # the expression on the true branch is to be considered the input for the
193193 # lower domain.
194194 # - viceversa, if the domain expression is unbound on upper side (positive
195195 # infinite), the true expression represents the input for the upper domain.
196- if mask_lower_bound == gtir_to_sdfg_utils .get_symbolic (gtir .InfinityLiteral .NEGATIVE ):
197- concat_dim_bound = mask_upper_bound
196+ if concat_domain . start == gtir_to_sdfg_utils .get_symbolic (gtir .InfinityLiteral .NEGATIVE ):
197+ concat_dim_bound = concat_domain . stop
198198 lower , lower_desc , lower_domain = (tb_field , tb_data_desc , tb_domain )
199199 upper , upper_desc , upper_domain = (fb_field , fb_data_desc , fb_domain )
200- elif mask_upper_bound == gtir_to_sdfg_utils .get_symbolic (gtir .InfinityLiteral .POSITIVE ):
201- concat_dim_bound = mask_lower_bound
200+ elif concat_domain . stop == gtir_to_sdfg_utils .get_symbolic (gtir .InfinityLiteral .POSITIVE ):
201+ concat_dim_bound = concat_domain . start
202202 lower , lower_desc , lower_domain = (fb_field , fb_data_desc , fb_domain )
203203 upper , upper_desc , upper_domain = (tb_field , tb_data_desc , tb_domain )
204204 else :
@@ -207,9 +207,9 @@ def _translate_concat_where_impl(
207207 # we use the concat domain, stored in the annex, as the domain of output field
208208 output_domain = gtir_domain .extract_domain (node_domain )
209209 output_dims , output_origin , output_shape = _get_concat_where_field_layout (
210- output_domain , concat_dim
210+ output_domain , concat_domain . dim
211211 )
212- concat_dim_index = output_dims .index (concat_dim )
212+ concat_dim_index = output_dims .index (concat_domain . dim )
213213
214214 """
215215 In case one of the arguments is a scalar value, for example:
@@ -225,23 +225,27 @@ def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField:
225225 assert isinstance (upper .gt_type , ts .FieldType )
226226 lower = gtir_to_sdfg_types .FieldopData (
227227 lower .dc_node ,
228- ts .FieldType (dims = [concat_dim ], dtype = lower .gt_type ),
228+ ts .FieldType (dims = [concat_domain . dim ], dtype = lower .gt_type ),
229229 origin = (concat_dim_bound - 1 ,),
230230 )
231- lower_bound = output_domain [concat_dim_index ][1 ]
232- lower_domain = [(concat_dim , lower_bound , concat_dim_bound )]
231+ lower_bound = output_domain [concat_dim_index ].start
232+ lower_domain = [
233+ gtir_domain .FieldopDomainRange (concat_domain .dim , lower_bound , concat_dim_bound )
234+ ]
233235 elif isinstance (upper .gt_type , ts .ScalarType ):
234236 assert len (upper_domain ) == 0
235237 assert isinstance (lower .gt_type , ts .FieldType )
236238 upper = gtir_to_sdfg_types .FieldopData (
237239 upper .dc_node ,
238- ts .FieldType (dims = [concat_dim ], dtype = upper .gt_type ),
240+ ts .FieldType (dims = [concat_domain . dim ], dtype = upper .gt_type ),
239241 origin = (concat_dim_bound ,),
240242 )
241- upper_bound = output_domain [concat_dim_index ][2 ]
242- upper_domain = [(concat_dim , concat_dim_bound , upper_bound )]
243+ upper_bound = output_domain [concat_dim_index ].stop
244+ upper_domain = [
245+ gtir_domain .FieldopDomainRange (concat_domain .dim , concat_dim_bound , upper_bound )
246+ ]
243247
244- if concat_dim not in lower .gt_type .dims : # type: ignore[union-attr]
248+ if concat_domain . dim not in lower .gt_type .dims : # type: ignore[union-attr]
245249 """
246250 The field on the lower domain is to be treated as a slice to add as one
247251 level in the concat dimension, on the lower bound.
@@ -261,13 +265,22 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
261265 ]
262266 )
263267 lower , lower_desc = _make_concat_field_slice (
264- sdfg , state , lower , lower_desc , concat_dim , concat_dim_index , concat_dim_bound - 1
268+ sdfg = sdfg ,
269+ state = state ,
270+ field = lower ,
271+ field_desc = lower_desc ,
272+ concat_dim = concat_domain .dim ,
273+ concat_dim_index = concat_dim_index ,
274+ concat_dim_origin = concat_dim_bound - 1 ,
265275 )
266276 lower_bound = dace .symbolic .pystr_to_symbolic (
267- f"max({ concat_dim_bound - 1 } , { output_domain [concat_dim_index ][1 ]} )"
277+ f"max({ concat_dim_bound - 1 } , { output_domain [concat_dim_index ].start } )"
278+ )
279+ lower_domain .insert (
280+ concat_dim_index ,
281+ gtir_domain .FieldopDomainRange (concat_domain .dim , lower_bound , concat_dim_bound ),
268282 )
269- lower_domain .insert (concat_dim_index , (concat_dim , lower_bound , concat_dim_bound ))
270- elif concat_dim not in upper .gt_type .dims : # type: ignore[union-attr]
283+ elif concat_domain .dim not in upper .gt_type .dims : # type: ignore[union-attr]
271284 # Same as previous case, but the field slice is added on the upper bound.
272285 assert (
273286 upper .gt_type .dims # type: ignore[union-attr]
@@ -277,12 +290,21 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
277290 ]
278291 )
279292 upper , upper_desc = _make_concat_field_slice (
280- sdfg , state , upper , upper_desc , concat_dim , concat_dim_index , concat_dim_bound
293+ sdfg = sdfg ,
294+ state = state ,
295+ field = upper ,
296+ field_desc = upper_desc ,
297+ concat_dim = concat_domain .dim ,
298+ concat_dim_index = concat_dim_index ,
299+ concat_dim_origin = concat_dim_bound ,
281300 )
282301 upper_bound = dace .symbolic .pystr_to_symbolic (
283- f"min({ concat_dim_bound + 1 } , { output_domain [concat_dim_index ][2 ]} )"
302+ f"min({ concat_dim_bound + 1 } , { output_domain [concat_dim_index ].stop } )"
303+ )
304+ upper_domain .insert (
305+ concat_dim_index ,
306+ gtir_domain .FieldopDomainRange (concat_domain .dim , concat_dim_bound , upper_bound ),
284307 )
285- upper_domain .insert (concat_dim_index , (concat_dim , concat_dim_bound , upper_bound ))
286308 elif isinstance (lower_desc , dace .data .Scalar ) or (
287309 len (lower .gt_type .dims ) == 1 and len (output_domain ) > 1 # type: ignore[union-attr]
288310 ):
@@ -297,27 +319,37 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
297319 return concat_where(KDim == 0, a, b)
298320 ```
299321 """
300- assert len (lower_domain ) == 1 and lower_domain [0 ][ 0 ] == concat_dim
322+ assert len (lower_domain ) == 1 and lower_domain [0 ]. dim == concat_domain . dim
301323 lower_domain = [
302324 * output_domain [:concat_dim_index ],
303325 lower_domain [0 ],
304326 * output_domain [concat_dim_index + 1 :],
305327 ]
306328 lower , lower_desc = _make_concat_scalar_broadcast (
307- sdfg , state , lower , lower_desc , lower_domain , concat_dim_index
329+ sdfg = sdfg ,
330+ state = state ,
331+ inp = lower ,
332+ inp_desc = lower_desc ,
333+ domain = lower_domain ,
334+ concat_dim_index = concat_dim_index ,
308335 )
309336 elif isinstance (upper_desc , dace .data .Scalar ) or (
310337 len (upper .gt_type .dims ) == 1 and len (output_domain ) > 1 # type: ignore[union-attr]
311338 ):
312339 # Same as previous case, but the scalar value is taken from `upper` input.
313- assert len (upper_domain ) == 1 and upper_domain [0 ][ 0 ] == concat_dim
340+ assert len (upper_domain ) == 1 and upper_domain [0 ]. dim == concat_domain . dim
314341 upper_domain = [
315342 * output_domain [:concat_dim_index ],
316343 upper_domain [0 ],
317344 * output_domain [concat_dim_index + 1 :],
318345 ]
319346 upper , upper_desc = _make_concat_scalar_broadcast (
320- sdfg , state , upper , upper_desc , upper_domain , concat_dim_index
347+ sdfg = sdfg ,
348+ state = state ,
349+ inp = upper ,
350+ inp_desc = upper_desc ,
351+ domain = upper_domain ,
352+ concat_dim_index = concat_dim_index ,
321353 )
322354 else :
323355 """
@@ -341,15 +373,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
341373 # ensure that the arguments have the same domain as the concat result
342374 assert all (ftype .dims == output_dims for ftype in (lower .gt_type , upper .gt_type )) # type: ignore[union-attr]
343375
344- lower_range_0 = output_domain [concat_dim_index ][ 1 ]
376+ lower_range_0 = output_domain [concat_dim_index ]. start
345377 lower_range_1 = dace .symbolic .pystr_to_symbolic (
346- f"max({ lower_range_0 } , { lower_domain [concat_dim_index ][ 2 ] } )"
378+ f"max({ lower_range_0 } , { lower_domain [concat_dim_index ]. stop } )"
347379 )
348380 lower_range_size = lower_range_1 - lower_range_0
349381
350- upper_range_1 = output_domain [concat_dim_index ][ 2 ]
382+ upper_range_1 = output_domain [concat_dim_index ]. stop
351383 upper_range_0 = dace .symbolic .pystr_to_symbolic (
352- f"min({ upper_range_1 } , { upper_domain [concat_dim_index ][ 1 ] } )"
384+ f"min({ upper_range_1 } , { upper_domain [concat_dim_index ]. start } )"
353385 )
354386 upper_range_size = upper_range_1 - upper_range_0
355387
@@ -391,15 +423,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
391423 else :
392424 lower_subset .append (
393425 (
394- output_domain [dim_index ][ 1 ] - lower .origin [dim_index ],
395- output_domain [dim_index ][ 1 ] - lower .origin [dim_index ] + size - 1 ,
426+ output_domain [dim_index ]. start - lower .origin [dim_index ],
427+ output_domain [dim_index ]. start - lower .origin [dim_index ] + size - 1 ,
396428 1 ,
397429 )
398430 )
399431 upper_subset .append (
400432 (
401- output_domain [dim_index ][ 1 ] - upper .origin [dim_index ],
402- output_domain [dim_index ][ 1 ] - upper .origin [dim_index ] + size - 1 ,
433+ output_domain [dim_index ]. start - upper .origin [dim_index ],
434+ output_domain [dim_index ]. start - upper .origin [dim_index ] + size - 1 ,
403435 1 ,
404436 )
405437 )
0 commit comments