Skip to content

Commit 178037a

Browse files
Squashed commit of the following:
commit 02327a9 Author: Philip Mueller <philip.mueller@cscs.ch> Date: Thu Jul 3 09:46:03 2025 +0200 Optimized `gpu_map_has_explicit_threadblocks()` the function will only once fully scann the tree. commit 6d13e68 Author: Philip Mueller <philip.mueller@cscs.ch> Date: Thu Jul 3 09:36:16 2025 +0200 Small modification that will improve debug experience. commit 8f4cbc5 Author: Philip Mueller <philip.mueller@cscs.ch> Date: Thu Jul 3 08:51:57 2025 +0200 Added a documentation. commit edeca0e Author: Philip Mueller <philip.mueller@cscs.ch> Date: Thu Jul 3 08:17:48 2025 +0200 This kind of works. commit a621c44 Author: Philip Mueller <philip.mueller@cscs.ch> Date: Wed Jul 2 12:53:15 2025 +0200 This shjould Fix Edoardo's case, but a test is missing.
1 parent 0deba99 commit 178037a

1 file changed

Lines changed: 32 additions & 8 deletions

File tree

dace/transformation/helpers.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,40 +1375,62 @@ def scope_tree_recursive(state: SDFGState, entry: Optional[nodes.EntryNode] = No
13751375
:param state: The state that contains the root of the scope tree.
13761376
:param entry: A scope entry node to set as root, otherwise the state is
13771377
the root if None is given.
1378+
1379+
:note: This function adds a `state` attribute to the `ScopeTree` objects, it refers to
1380+
the state to which the scope was found.
13781381
"""
1382+
# The first clear is to make sure that the data structure we get is not referred to by any
1383+
# other reference. The second clear is needed to ensure that nobody gets the object we
1384+
# operate on. This is because we actually modifying them and they are cached by the state.
1385+
# The real error happens if this function is called multiple times, without a refresh.
1386+
# NOTE: `scope_tree()` only performs a shallow copy, but only of the `dict` that we do not use.
1387+
state._clear_scopedict_cache()
13791388
stree = state.scope_tree()[entry]
1389+
state._clear_scopedict_cache()
1390+
13801391
stree.state = state # Annotate state in tree
13811392

13821393
# Add nested SDFGs as children
13831394
def traverse(state: SDFGState, treenode: ScopeTree):
1395+
# See above why.
1396+
state._clear_scopedict_cache()
13841397
snodes = state.scope_children()[treenode.entry]
1398+
state._clear_scopedict_cache()
1399+
13851400
for node in snodes:
13861401
if isinstance(node, nodes.NestedSDFG):
13871402
for nstate in node.sdfg.states():
13881403
ntree = nstate.scope_tree()[None]
1404+
assert ntree not in treenode.children
1405+
assert not hasattr(ntree, "state") # Non standard field.
13891406
ntree.state = nstate
13901407
treenode.children.append(ntree)
1408+
13911409
for child in treenode.children:
13921410
if hasattr(child, 'state') and child.state != state:
1393-
traverse(getattr(child, 'state', state), child)
1411+
traverse(child.state, child)
13941412

13951413
traverse(state, stree)
13961414
return stree
13971415

13981416

13991417
def get_internal_scopes(state: SDFGState,
14001418
entry: nodes.EntryNode,
1401-
immediate: bool = False) -> List[Tuple[SDFGState, nodes.EntryNode]]:
1419+
immediate: bool = False,
1420+
recursive_scope_tree: Optional[ScopeTree] = None) -> List[Tuple[SDFGState, nodes.EntryNode]]:
14021421
"""
14031422
Returns all internal scopes within a given scope, including if they
14041423
reside in nested SDFGs.
14051424
14061425
:param state: State in which entry node resides.
14071426
:param entry: The entry node to start from.
1408-
:param immediate: If True, only returns the scopes that are immediately
1409-
nested in the map.
1427+
:param immediate: If True, only returns the scopes that are immediately nested in the map.
1428+
:param recursive_scope_tree: The recursive scope tree, see `scope_tree_recursive()` for more.
14101429
"""
1411-
stree = scope_tree_recursive(state, entry)
1430+
1431+
if recursive_scope_tree is None:
1432+
recursive_scope_tree = scope_tree_recursive(state, entry)
1433+
14121434
result = []
14131435

14141436
def traverse(state: SDFGState, treenode: ScopeTree):
@@ -1420,19 +1442,21 @@ def traverse(state: SDFGState, treenode: ScopeTree):
14201442
else: # Nested SDFG
14211443
traverse(child.state, child)
14221444

1423-
traverse(state, stree)
1445+
traverse(state, recursive_scope_tree)
14241446
return result
14251447

14261448

14271449
def gpu_map_has_explicit_threadblocks(state: SDFGState, entry: nodes.EntryNode) -> bool:
14281450
"""
14291451
Returns True if GPU_Device map has explicit thread-block maps nested within.
14301452
"""
1431-
internal_maps = get_internal_scopes(state, entry)
1453+
rstree = scope_tree_recursive(state, entry)
1454+
internal_maps = get_internal_scopes(state, entry, recursive_scope_tree=rstree)
14321455
if any(m.schedule in (dtypes.ScheduleType.GPU_ThreadBlock, dtypes.ScheduleType.GPU_ThreadBlock_Dynamic)
14331456
for _, m in internal_maps):
14341457
return True
1435-
imm_maps = get_internal_scopes(state, entry, immediate=True)
1458+
1459+
imm_maps = get_internal_scopes(state, entry, immediate=True, recursive_scope_tree=rstree)
14361460
if any(m.schedule == dtypes.ScheduleType.Default for _, m in imm_maps):
14371461
return True
14381462

0 commit comments

Comments
 (0)