@@ -752,7 +752,10 @@ def concatenate_inputs(
752752 ts .FieldType (dims = [concat_dim ], dtype = lower .gt_type ),
753753 origin = (concat_dim_bound - 1 ,),
754754 )
755- lower_domain = [(concat_dim , concat_dim_bound - 1 , concat_dim_bound )]
755+ lower_bound = dace .symbolic .pystr_to_symbolic (
756+ f"max({ concat_dim_bound - 1 } , { output_domain [concat_dim_index ][1 ]} )"
757+ )
758+ lower_domain = [(concat_dim , lower_bound , concat_dim_bound )]
756759 elif isinstance (upper .gt_type , ts .ScalarType ):
757760 assert len (upper_domain ) == 0
758761 assert isinstance (lower .gt_type , ts .FieldType )
@@ -761,7 +764,10 @@ def concatenate_inputs(
761764 ts .FieldType (dims = [concat_dim ], dtype = upper .gt_type ),
762765 origin = (concat_dim_bound ,),
763766 )
764- upper_domain = [(concat_dim , concat_dim_bound , concat_dim_bound + 1 )]
767+ upper_bound = dace .symbolic .pystr_to_symbolic (
768+ f"min({ concat_dim_bound + 1 } , { output_domain [concat_dim_index ][2 ]} )"
769+ )
770+ upper_domain = [(concat_dim , concat_dim_bound , upper_bound )]
765771
766772 is_lower_slice , is_upper_slice = (False , False )
767773 if concat_dim not in lower .gt_type .dims :
0 commit comments