diff --git a/pyproject.toml b/pyproject.toml index f60732b..090e118 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ sdist.include = [ "src/gpubackendtools/cutils/gbt_global.h", "src/gpubackendtools/cutils/pybind11_cuda_array_interface.hpp", "src/gpubackendtools/cutils/cmake_functions.cmake", + "src/gpubackendtools/cutils/gbt_binding.hpp" ] sdist.exclude = [ ".devcontainer/", diff --git a/src/gpubackendtools/__init__.py b/src/gpubackendtools/__init__.py index 48631c6..daded45 100644 --- a/src/gpubackendtools/__init__.py +++ b/src/gpubackendtools/__init__.py @@ -36,12 +36,13 @@ from .parallelbase import ParallelModuleBase from .globals import Globals -from .cutils import GBTCpuBackend, GBTCuda11xBackend, GBTCuda12xBackend +from .cutils import GBTCpuBackend, GBTCuda11xBackend, GBTCuda12xBackend, GBTCuda13xBackend add_backends = { "gbt_cpu": GBTCpuBackend, "gbt_cuda11x": GBTCuda11xBackend, "gbt_cuda12x": GBTCuda12xBackend, + "gbt_cuda13x": GBTCuda13xBackend, } Globals().backends_manager.add_backends(add_backends) diff --git a/src/gpubackendtools/cutils/__init__.py b/src/gpubackendtools/cutils/__init__.py index 4322492..bed048d 100644 --- a/src/gpubackendtools/cutils/__init__.py +++ b/src/gpubackendtools/cutils/__init__.py @@ -6,7 +6,7 @@ import abc from typing import Optional, Sequence, TypeVar, Union -from ..gpubackendtools import BackendMethods, CpuBackend, Cuda11xBackend, Cuda12xBackend +from ..gpubackendtools import BackendMethods, CpuBackend, Cuda11xBackend, Cuda12xBackend, Cuda13xBackend from ..exceptions import * @dataclasses.dataclass @@ -126,6 +126,39 @@ def cuda12x_module_loader(): CubicSpline=gbt_backend_cuda12x.interp.CubicSplineGPU, xp=cupy, ) + +class GBTCuda13xBackend(Cuda13xBackend, GBTBackend): + """Implementation of CUDA 13.x backend""" + _backend_name : str = "gbt_backend_cuda13x" + _name = "gbt_cuda13x" + + def __init__(self, *args, **kwargs): + Cuda13xBackend.__init__(self, *args, **kwargs) + GBTBackend.__init__(self, self.cuda13x_module_loader()) + + @staticmethod + def cuda13x_module_loader(): + try: + import gbt_backend_cuda13x.interp + + except (ModuleNotFoundError, ImportError) as e: + raise BackendUnavailableException( + "'cuda13x' backend could not be imported." + ) from e + + try: + import cupy + except (ModuleNotFoundError, ImportError) as e: + raise MissingDependencies( + "'cuda13x' backend requires cupy", pip_deps=["cupy-cuda13x"] + ) from e + + return GBTBackendMethods( + interpolate_wrap=gbt_backend_cuda13x.interp.interpolate_wrap, + CubicSplineWrap=gbt_backend_cuda13x.interp.CubicSplineWrapGPU, + CubicSpline=gbt_backend_cuda13x.interp.CubicSplineGPU, + xp=cupy, + ) """List of existing backends, per default order of preference.""" # TODO: __all__ ?