Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,29 @@ class Device:
total = system.get_num_devices()
return tuple(cls(device_id) for device_id in range(total))

def to_system_device(self) -> 'cuda.core.system.Device':
    """
    Return the :class:`cuda.core.system.Device` paired with this device.

    A :class:`cuda.core.Device` is the handle used for CUDA access, whereas
    a :class:`cuda.core.system.Device` is the handle used for NVIDIA
    Management Library (NVML) access. The pairing between the two is made
    through the device UUID.

    Returns
    -------
    cuda.core.system.Device
        The system-level device, constructed from this device's UUID, that
        is used for NVML access.

    Raises
    ------
    RuntimeError
        If ``CUDA_BINDINGS_NVML_IS_COMPATIBLE`` is false, i.e. the installed
        cuda_bindings does not meet the version named in the message.
    """
    # Local imports: the compatibility flag is checked before the system
    # Device class is imported.
    from cuda.core.system._system import (
        CUDA_BINDINGS_NVML_IS_COMPATIBLE as nvml_is_compatible,
    )

    if nvml_is_compatible:
        from cuda.core.system import Device as SystemDevice

        own_uuid = self.uuid
        return SystemDevice(uuid=own_uuid)

    raise RuntimeError(
        "cuda.core.system.Device requires cuda_bindings 13.1.2+ or 12.9.6+"
    )

@property
def device_id(self) -> int:
"""Return device ordinal."""
Expand Down
30 changes: 30 additions & 0 deletions cuda_core/cuda/core/system/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,36 @@ cdef class Device:
pci_bus_id = pci_bus_id.decode("ascii")
self._handle = nvml.device_get_handle_by_pci_bus_id_v2(pci_bus_id)

def to_cuda_device(self) -> "cuda.core.Device":
    """
    Get the corresponding :class:`cuda.core.Device` (which is used for CUDA
    access) for this :class:`cuda.core.system.Device` (which is used for
    NVIDIA Management Library (NVML) access).

    The devices are mapped to one another by their UUID.

    Returns
    -------
    cuda.core.Device
        The corresponding CUDA device.

    Raises
    ------
    RuntimeError
        If none of the CUDA devices has a UUID matching this NVML device.
    """
    from cuda.core import Device as CudaDevice

    # CUDA does not have an API to get a device by its UUID, so we just
    # search all the devices for one with a matching UUID.

    # NVML UUIDs have a `GPU-` or `MIG-` prefix, which the CUDA-side UUID
    # does not carry. Strip it only when it is actually present, so that a
    # UUID without a prefix is compared as-is instead of losing its first
    # four characters. Possibly we should only do this matching when it has
    # a `GPU-` prefix; if no matching CUDA device can be found, the
    # exception below reports the UUID that was searched for.
    nvml_uuid = self.uuid
    if nvml_uuid.startswith(("GPU-", "MIG-")):
        cuda_uuid = nvml_uuid[4:]
    else:
        cuda_uuid = nvml_uuid

    for cuda_device in CudaDevice.get_all_devices():
        if cuda_device.uuid == cuda_uuid:
            return cuda_device

    raise RuntimeError(
        "No corresponding CUDA device found for this NVML device. "
        f"(NVML UUID: {nvml_uuid})"
    )

@classmethod
def get_device_count(cls) -> int:
"""
Expand Down
17 changes: 17 additions & 0 deletions cuda_core/tests/system/test_system_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ def test_device_count():
assert system.Device.get_device_count() == system.get_num_devices()


def test_to_cuda_device():
    """Every NVML device must map to a CUDA device with matching identity."""
    from cuda.core import Device as CudaDevice

    for nvml_device in system.Device.get_all_devices():
        mapped = nvml_device.to_cuda_device()
        assert isinstance(mapped, CudaDevice)

        # The NVML UUID carries a 4-character prefix that the CUDA UUID
        # does not have, so drop it before comparing.
        cuda_uuid = mapped.uuid
        expected_uuid = nvml_device.uuid[4:]
        assert cuda_uuid == expected_uuid

        # Technically, this check only works with PCI devices, but are there
        # non-PCI devices we need to support?

        # The PCI bus ID domain is 2 bytes wide as reported by CUDA and
        # 4 bytes wide as reported by NVML; skip the extra leading
        # characters of the NVML value.
        cuda_bus_id = mapped.pci_bus_id
        expected_bus_id = nvml_device.pci_info.bus_id[4:]
        assert cuda_bus_id == expected_bus_id


def test_device_architecture():
for device in system.Device.get_all_devices():
device_arch = device.architecture
Expand Down
24 changes: 24 additions & 0 deletions cuda_core/tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ def cuda_version():
return _py_major_ver, _driver_ver


def test_to_system_device(deinit_cuda):
    """A CUDA device must map to the NVML device with matching identity."""
    from cuda.core.system import _system

    cuda_device = Device()

    nvml_is_compatible = _system.CUDA_BINDINGS_NVML_IS_COMPATIBLE
    if not nvml_is_compatible:
        # Without NVML support the conversion must fail loudly; there is
        # nothing further to verify in that configuration.
        with pytest.raises(RuntimeError):
            cuda_device.to_system_device()
        pytest.skip("NVML support requires cuda.bindings version 12.9.6+ or 13.1.2+")

    from cuda.core.system import Device as SystemDevice

    nvml_device = cuda_device.to_system_device()
    assert isinstance(nvml_device, SystemDevice)

    # The NVML UUID carries a 4-character prefix that the CUDA UUID does
    # not have, so drop it before comparing.
    nvml_uuid_body = nvml_device.uuid[4:]
    assert nvml_uuid_body == cuda_device.uuid

    # Technically, this check only works with PCI devices, but are there
    # non-PCI devices we need to support?

    # The PCI bus ID domain is 2 bytes wide as reported by CUDA and 4 bytes
    # wide as reported by NVML; skip the extra leading characters of the
    # NVML value.
    cuda_bus_id = cuda_device.pci_bus_id
    nvml_bus_id_tail = nvml_device.pci_info.bus_id[4:]
    assert cuda_bus_id == nvml_bus_id_tail


def test_device_set_current(deinit_cuda):
device = Device()
device.set_current()
Expand Down
Loading