Skip to content

Commit bc9425e

Browse files
committed
fix domain in scalar broadcast
1 parent 8a4308b commit bc9425e

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,10 @@ def concatenate_inputs(
752752
ts.FieldType(dims=[concat_dim], dtype=lower.gt_type),
753753
origin=(concat_dim_bound - 1,),
754754
)
755-
lower_domain = [(concat_dim, concat_dim_bound - 1, concat_dim_bound)]
755+
lower_bound = dace.symbolic.pystr_to_symbolic(
756+
f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})"
757+
)
758+
lower_domain = [(concat_dim, lower_bound, concat_dim_bound)]
756759
elif isinstance(upper.gt_type, ts.ScalarType):
757760
assert len(upper_domain) == 0
758761
assert isinstance(lower.gt_type, ts.FieldType)
@@ -761,7 +764,10 @@ def concatenate_inputs(
761764
ts.FieldType(dims=[concat_dim], dtype=upper.gt_type),
762765
origin=(concat_dim_bound,),
763766
)
764-
upper_domain = [(concat_dim, concat_dim_bound, concat_dim_bound + 1)]
767+
upper_bound = dace.symbolic.pystr_to_symbolic(
768+
f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})"
769+
)
770+
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]
765771

766772
is_lower_slice, is_upper_slice = (False, False)
767773
if concat_dim not in lower.gt_type.dims:

0 commit comments

Comments
 (0)