@@ -1375,40 +1375,62 @@ def scope_tree_recursive(state: SDFGState, entry: Optional[nodes.EntryNode] = No
13751375 :param state: The state that contains the root of the scope tree.
13761376 :param entry: A scope entry node to set as root, otherwise the state is
13771377 the root if None is given.
1378+
1379+ :note: This function adds a `state` attribute to the `ScopeTree` objects, it refers to
1380+ the state to which the scope was found.
13781381 """
1382+ # The first clear is to make sure that the data structure we get is not referred to by any
1383+ # other reference. The second clear is needed to ensure that nobody gets the object we
1384+ # operate on. This is because we actually modifying them and they are cached by the state.
1385+ # The real error happens if this function is called multiple times, without a refresh.
1386+ # NOTE: `scope_tree()` only performs a shallow copy, but only of the `dict` that we do not use.
1387+ state ._clear_scopedict_cache ()
13791388 stree = state .scope_tree ()[entry ]
1389+ state ._clear_scopedict_cache ()
1390+
13801391 stree .state = state # Annotate state in tree
13811392
13821393 # Add nested SDFGs as children
13831394 def traverse (state : SDFGState , treenode : ScopeTree ):
1395+ # See above why.
1396+ state ._clear_scopedict_cache ()
13841397 snodes = state .scope_children ()[treenode .entry ]
1398+ state ._clear_scopedict_cache ()
1399+
13851400 for node in snodes :
13861401 if isinstance (node , nodes .NestedSDFG ):
13871402 for nstate in node .sdfg .states ():
13881403 ntree = nstate .scope_tree ()[None ]
1404+ assert ntree not in treenode .children
1405+ assert not hasattr (ntree , "state" ) # Non standard field.
13891406 ntree .state = nstate
13901407 treenode .children .append (ntree )
1408+
13911409 for child in treenode .children :
13921410 if hasattr (child , 'state' ) and child .state != state :
1393- traverse (getattr ( child , ' state' , state ) , child )
1411+ traverse (child . state , child )
13941412
13951413 traverse (state , stree )
13961414 return stree
13971415
13981416
13991417def get_internal_scopes (state : SDFGState ,
14001418 entry : nodes .EntryNode ,
1401- immediate : bool = False ) -> List [Tuple [SDFGState , nodes .EntryNode ]]:
1419+ immediate : bool = False ,
1420+ recursive_scope_tree : Optional [ScopeTree ] = None ) -> List [Tuple [SDFGState , nodes .EntryNode ]]:
14021421 """
14031422 Returns all internal scopes within a given scope, including if they
14041423 reside in nested SDFGs.
14051424
14061425 :param state: State in which entry node resides.
14071426 :param entry: The entry node to start from.
1408- :param immediate: If True, only returns the scopes that are immediately
1409- nested in the map .
1427+ :param immediate: If True, only returns the scopes that are immediately nested in the map.
1428+ :param recursive_scope_tree: The recursive scope tree, see `scope_tree_recursive()` for more .
14101429 """
1411- stree = scope_tree_recursive (state , entry )
1430+
1431+ if recursive_scope_tree is None :
1432+ recursive_scope_tree = scope_tree_recursive (state , entry )
1433+
14121434 result = []
14131435
14141436 def traverse (state : SDFGState , treenode : ScopeTree ):
@@ -1420,19 +1442,21 @@ def traverse(state: SDFGState, treenode: ScopeTree):
14201442 else : # Nested SDFG
14211443 traverse (child .state , child )
14221444
1423- traverse (state , stree )
1445+ traverse (state , recursive_scope_tree )
14241446 return result
14251447
14261448
14271449def gpu_map_has_explicit_threadblocks (state : SDFGState , entry : nodes .EntryNode ) -> bool :
14281450 """
14291451 Returns True if GPU_Device map has explicit thread-block maps nested within.
14301452 """
1431- internal_maps = get_internal_scopes (state , entry )
1453+ rstree = scope_tree_recursive (state , entry )
1454+ internal_maps = get_internal_scopes (state , entry , recursive_scope_tree = rstree )
14321455 if any (m .schedule in (dtypes .ScheduleType .GPU_ThreadBlock , dtypes .ScheduleType .GPU_ThreadBlock_Dynamic )
14331456 for _ , m in internal_maps ):
14341457 return True
1435- imm_maps = get_internal_scopes (state , entry , immediate = True )
1458+
1459+ imm_maps = get_internal_scopes (state , entry , immediate = True , recursive_scope_tree = rstree )
14361460 if any (m .schedule == dtypes .ScheduleType .Default for _ , m in imm_maps ):
14371461 return True
14381462
0 commit comments