Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,12 @@ def unparse_cr(sdfg, wcr_ast, dtype):
def connected_to_gpu_memory(node: nodes.Node, state: SDFGState, sdfg: SDFG):
for e in state.all_edges(node):
path = state.memlet_path(e)
if ((isinstance(path[0].src, nodes.AccessNode)
and path[0].src.desc(sdfg).storage is dtypes.StorageType.GPU_Global)):
if (((isinstance(path[0].src, nodes.AccessNode)
and path[0].src.desc(sdfg).storage is dtypes.StorageType.GPU_Global))
or ((isinstance(path[-1].dst, nodes.AccessNode)
and path[-1].dst.desc(sdfg).storage is dtypes.StorageType.GPU_Global))):
return True

return False


Expand Down
8 changes: 6 additions & 2 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class ScheduleType(ExtensibleAttributeEnum):
StorageType.GPU_Shared,
]

GPU_KERNEL_ACCESSIBLE_STORAGES = [StorageType.GPU_Global, StorageType.GPU_Shared, StorageType.Register]


class ReductionType(Enum):
""" Reduction types natively supported by the SDFG compiler. """
Expand Down Expand Up @@ -176,7 +178,7 @@ class TilingType(Enum):
ScheduleType.GPU_ThreadBlock: StorageType.Register,
ScheduleType.GPU_ThreadBlock_Dynamic: StorageType.Register,
ScheduleType.SVE_Map: StorageType.CPU_Heap,
ScheduleType.Snitch: StorageType.Snitch_TCDM
ScheduleType.Snitch: StorageType.Snitch_TCDM,
}

# Maps from ScheduleType to default ScheduleType for sub-scopes
Expand All @@ -193,7 +195,7 @@ class TilingType(Enum):
ScheduleType.GPU_ThreadBlock_Dynamic: ScheduleType.Sequential,
ScheduleType.SVE_Map: ScheduleType.Sequential,
ScheduleType.Snitch: ScheduleType.Snitch,
ScheduleType.Snitch_Multicore: ScheduleType.Snitch_Multicore
ScheduleType.Snitch_Multicore: ScheduleType.Snitch_Multicore,
}

# Maps from StorageType to a preferred ScheduleType for helping determine schedules.
Expand Down Expand Up @@ -1186,6 +1188,7 @@ class complex128(_DaCeArray, npt.NDArray[numpy.complex128]): ...
class string(_DaCeArray, npt.NDArray[numpy.str_]): ...
class vector(_DaCeArray, npt.NDArray[numpy.void]): ...
class MPI_Request(_DaCeArray, npt.NDArray[numpy.void]): ...
class gpuStream_t(_DaCeArray, npt.NDArray[numpy.void]): ...
# yapf: enable
else:
# Runtime definitions
Expand All @@ -1206,6 +1209,7 @@ class MPI_Request(_DaCeArray, npt.NDArray[numpy.void]): ...
complex128 = typeclass(numpy.complex128)
string = stringtype()
MPI_Request = opaque('MPI_Request')
gpuStream_t = opaque('gpuStream_t')

_bool = bool

Expand Down
Empty file.
Empty file.
Loading
Loading