diff --git a/lmcache/v1/multiprocess/transfer_context/shm.py b/lmcache/v1/multiprocess/transfer_context/shm.py index 0178d27faf..7f833ba656 100644 --- a/lmcache/v1/multiprocess/transfer_context/shm.py +++ b/lmcache/v1/multiprocess/transfer_context/shm.py @@ -2,6 +2,7 @@ """Shared-memory NonGpuContext implementation for multiprocess mode.""" # Standard +import ctypes from dataclasses import dataclass from multiprocessing import shared_memory from multiprocessing.resource_tracker import unregister @@ -11,6 +12,8 @@ import torch # First Party +from lmcache import torch_dev +from lmcache.logging import init_logger from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey from lmcache.v1.multiprocess.mq import MessageQueueClient from lmcache.v1.multiprocess.protocol import RequestType, get_response_class @@ -19,6 +22,8 @@ NonGpuContextMetadata, ) +logger = init_logger(__name__) + @dataclass(frozen=True) class ShmSlotDescriptor: @@ -92,6 +97,9 @@ def __init__( self._pool_size = pool_size self._shm: shared_memory.SharedMemory | None = None self._shm_buffer: memoryview | None = None + self._pinned = False + self._pinned_ptr = 0 + self._pinned_size = 0 try: self._shm = shared_memory.SharedMemory( name=shm_name.lstrip("/"), create=False @@ -101,6 +109,7 @@ def __init__( # unlink the segment when this worker exits. unregister(f"/{self._shm.name}", "shared_memory") self._shm_buffer = self._shm.buf + self._register_shm_buffer() except Exception: self._shm = None self._shm_buffer = None @@ -212,7 +221,70 @@ def close(self) -> None: if self._shm is None: return try: - self._shm.close() + self._unregister_shm_buffer() finally: - self._shm = None - self._shm_buffer = None + try: + self._shm.close() + finally: + self._shm = None + self._shm_buffer = None + + def _register_shm_buffer(self) -> None: + if self._shm_buffer is None or not torch_dev.is_available(): + return + if not hasattr(torch_dev, "cudart"): + logger.warning( + "Skipping SHM host registration for shm_name=%s: " + "backend does not support cudart(); D2H copies will be synchronous", + self._shm_name, + ) + return + try: + ptr = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_buffer)) + err = torch_dev.cudart().cudaHostRegister(ptr, self._pool_size, 0) + except Exception as exc: + logger.warning( + "Failed to register SHM buffer for shm_name=%s: %r; " + "D2H copies will be synchronous", + self._shm_name, + exc, + ) + return + if err != 0: + logger.warning( + "cudaHostRegister failed for shm_name=%s (ptr=%d, size=%d, err=%s); " + "D2H copies will be synchronous", + self._shm_name, + ptr, + self._pool_size, + err, + ) + return + self._pinned = True + self._pinned_ptr = ptr + self._pinned_size = self._pool_size + + def _unregister_shm_buffer(self) -> None: + if not self._pinned or self._pinned_ptr == 0: + return + try: + err = torch_dev.cudart().cudaHostUnregister(self._pinned_ptr) + if err != 0: + logger.warning( + "cudaHostUnregister failed for shm_name=%s (ptr=%d, size=%d, " + "err=%s)", + self._shm_name, + self._pinned_ptr, + self._pinned_size, + err, + ) + except Exception as exc: + logger.warning( + "Failed to unregister SHM buffer for shm_name=%s: %r", + self._shm_name, + exc, + ) + finally: + self._pinned = False + self._pinned_ptr = 0 + self._pinned_size = 0 diff --git a/tests/v1/multiprocess/test_non_cuda_data_transfer.py b/tests/v1/multiprocess/test_non_cuda_data_transfer.py index c60290b917..1aaa38b6f7 100644 --- a/tests/v1/multiprocess/test_non_cuda_data_transfer.py +++ b/tests/v1/multiprocess/test_non_cuda_data_transfer.py @@ -1162,3 +1162,134 @@ def test_non_gpu_context_shm_close_is_idempotent() -> None: finally: if os.path.exists(shm_path): os.unlink(shm_path) + + +def test_non_gpu_context_shm_registers_and_unregisters_host_memory( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 0 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + ptr, size, flags = fake_cudart.register_calls[0] + assert ptr > 0 + assert size == 4096 + assert flags == 0 + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [ptr] + + +def test_non_gpu_context_shm_register_failure_warns_and_skips_unregister( + monkeypatch: Any, +) -> None: + shm_name = f"lmcache_test_pin_fail_{os.getpid()}" + shm_path = _create_shm_file(shm_name, 4096) + + class FakeCudaRt: + def __init__(self) -> None: + self.register_calls: list[tuple[int, int, int]] = [] + self.unregister_calls: list[int] = [] + + def cudaHostRegister(self, ptr: int, size: int, flags: int) -> int: + self.register_calls.append((ptr, size, flags)) + return 1 + + def cudaHostUnregister(self, ptr: int) -> int: + self.unregister_calls.append(ptr) + return 0 + + class FakeTorchDev: + def __init__(self, cudart: FakeCudaRt) -> None: + self._cudart = cudart + + def is_available(self) -> bool: + return True + + def cudart(self) -> FakeCudaRt: + return self._cudart + + # First Party + import lmcache.v1.multiprocess.transfer_context.shm as shm_module + + fake_cudart = FakeCudaRt() + monkeypatch.setattr(shm_module, "torch_dev", FakeTorchDev(fake_cudart)) + + with patch.object(shm_module.logger, "warning") as warning_mock: + context = NonGpuContextShm( + metadata=NonGpuContextMetadata( + layout_desc=MemoryLayoutDesc( + shapes=[torch.Size([2, 2])], + dtypes=[torch.float32], + ), + block_size=1, + use_mla=False, + ), + mq_client=MagicMock(), + mq_timeout=1.0, + shm_name=shm_name, + pool_size=4096, + ) + try: + assert len(fake_cudart.register_calls) == 1 + finally: + context.close() + if os.path.exists(shm_path): + os.unlink(shm_path) + + assert fake_cudart.unregister_calls == [] + warning_mock.assert_called_once() + message, logged_shm_name, _logged_ptr, logged_size, logged_err = ( + warning_mock.call_args[0] + ) + assert "cudaHostRegister failed" in message + assert logged_shm_name == shm_name + assert logged_size == 4096 + assert logged_err == 1