1616
1717from gt4py .next import common as gtx_common , utils as gtx_utils
1818from gt4py .next .iterator import ir as gtir
19- from gt4py .next .iterator .ir_utils import common_pattern_matcher as cpm , domain_utils
19+ from gt4py .next .iterator .ir_utils import (
20+ common_pattern_matcher as cpm ,
21+ domain_utils ,
22+ ir_makers as im ,
23+ )
2024from gt4py .next .program_processors .runners .dace import (
2125 gtir_dataflow ,
2226 gtir_domain ,
@@ -247,10 +251,16 @@ def translate_as_fieldop(
247251 raise NotImplementedError ("Unexpected 'as_filedop' with tuple output in SDFG lowering." )
248252
249253 if cpm .is_ref_to (fieldop_expr , "deref" ):
250- # Special usage of 'deref' as argument to fieldop expression, to pass a scalar
251- # value to 'as_fieldop' function. It results in broadcasting the scalar value
252- # over the field domain.
253- return translate_broadcast (node , ctx , sdfg_builder )
254+ if isinstance (node .args [0 ].type , ts .ScalarType ):
255+ # Special usage of 'deref' as argument to fieldop expression, to broadcast
256+ # a scalar value on the field domain.
257+ return translate_broadcast (node , ctx , sdfg_builder )
258+ else :
259+ # Special usage of 'deref' with field argument, to return a subset of
260+ # the full field domain.
261+ # TODO(edopao): Lower this case to a memlet edge, planned for next PR.
262+ stencil_expr = im .lambda_ ("a" )(im .deref ("a" ))
263+ stencil_expr .expr .type = node .type .dtype
254264 elif isinstance (fieldop_expr , gtir .Lambda ):
255265 # Default case, handled below: the argument expression is a lambda function
256266 # representing the stencil operation to be computed over the field domain.
@@ -285,6 +295,7 @@ def translate_broadcast(
285295 ctx : gtir_to_sdfg .SubgraphContext ,
286296 sdfg_builder : gtir_to_sdfg .SDFGBuilder ,
287297) -> gtir_to_sdfg_types .FieldopData :
298+ """Translates a broadcast expression which writes a scalar value on the field domain."""
288299 assert isinstance (node , gtir .FunCall )
289300 assert cpm .is_call_to (node .fun , "as_fieldop" )
290301
@@ -294,6 +305,10 @@ def translate_broadcast(
294305 assert isinstance (node .type .dtype , ts .ScalarType )
295306 field_dtype = gtx_dace_utils .as_dace_type (node .type .dtype )
296307
308+ assert len (node .args ) == 1
309+ assert isinstance (node .args [0 ].type , ts .ScalarType )
310+ scalar_arg = node .args [0 ]
311+
297312 fun_node = node .fun
298313 assert len (fun_node .args ) == 2
299314 fieldop_expr , fieldop_domain_expr = fun_node .args
@@ -308,32 +323,23 @@ def translate_broadcast(
308323 # The memory layout of the output field follows the field operator compute domain.
309324 field_dims , field_origin , field_shape = gtir_domain .get_field_layout (field_domain )
310325 assert field_dims == node .type .dims
311-
312326 field_name , field_desc = sdfg_builder .add_temp_array (ctx .sdfg , field_shape , field_dtype )
313327 field_node = ctx .state .add_access (field_name )
314328
315329 # Retrieve the scalar argument, which could be either a literal value or the
316330 # result of a scalar expression.
317- assert len (node .args ) == 1
331+ arg = _parse_fieldop_arg (scalar_arg , ctx , sdfg_builder , field_domain )
332+ assert isinstance (arg , gtir_dataflow .MemletExpr )
333+ assert arg .subset .num_elements () == 1
318334
319335 # Use a 'Fill' library node to write the scalar value to the result field.
320- if isinstance (node .args [0 ], gtir .Literal ):
321- assert node .args [0 ].type == node .type .dtype
322- value = field_dtype (node .args [0 ].value )
323- fill_node = sdfg_library_nodes .Fill ("fill" , value )
324- ctx .state .add_node (fill_node )
325- else :
326- arg = _parse_fieldop_arg (node .args [0 ], ctx , sdfg_builder , field_domain )
327- assert isinstance (arg , gtir_dataflow .MemletExpr )
328- assert arg .subset .num_elements () == 1
329-
330- fill_node = sdfg_library_nodes .Fill ("fill" )
331- ctx .state .add_node (fill_node )
332- ctx .state .add_nedge (
333- arg .dc_node ,
334- fill_node ,
335- dace .Memlet (data = arg .dc_node .data , subset = arg .subset ),
336- )
336+ fill_node = sdfg_library_nodes .Fill ("fill" )
337+ ctx .state .add_node (fill_node )
338+ ctx .state .add_nedge (
339+ arg .dc_node ,
340+ fill_node ,
341+ dace .Memlet (data = arg .dc_node .data , subset = arg .subset ),
342+ )
337343
338344 ctx .state .add_nedge (
339345 fill_node ,
0 commit comments