Register worker SHM mappings for truly async D2H KV copies#369

Draft
Copilot wants to merge 3 commits into copilot/ww24-pr-async-again from copilot/fix-gpu-copy-non-blocking

Copilot AI commented Jun 16, 2026

gather_paged_kv_to_cpu(..., out=...) was issuing non_blocking=True D2H copies into SHM-backed CPU buffers, but those worker-side mappings were not registered as pinned host memory in the worker’s CUDA context. As a result, copies to SHM fell back to synchronous behavior despite using the async path.

  • Worker-side SHM host registration

    • Register the attached SHM buffer in NonGpuContextShm immediately after shared_memory.SharedMemory(create=False) succeeds.
    • Use the mapped buffer address and pool size with torch_dev.cudart().cudaHostRegister(..., 0) so the worker’s virtual mapping is recognized by the GPU DMA path.
  • Lifecycle tracking and cleanup

    • Track registration state on the context (_pinned, pointer, size).
    • Unregister the worker mapping with cudaHostUnregister during close() before releasing the SHM segment.
  • Graceful fallback on unsupported backends / registration failure

    • Guard registration behind torch_dev.is_available() and hasattr(torch_dev, "cudart").
    • If registration is unavailable or fails, log a warning and continue with the existing synchronous fallback behavior.
  • Focused coverage

    • Add tests covering successful worker-side registration/unregistration.
    • Add tests covering registration failure: warning emitted, no unregister attempt.
self._shm = shared_memory.SharedMemory(name=shm_name.lstrip("/"), create=False)
self._shm_buffer = self._shm.buf

ptr = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_buffer))
err = torch_dev.cudart().cudaHostRegister(ptr, self._pool_size, 0)
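
The register, track, and unregister flow described above could look roughly like the sketch below. `PinnedShmAttachment`, `FakeCudart`, and `run_once` are hypothetical names that do not come from the PR. The cudart object is injected so the example runs without a GPU, whereas the PR obtains it from `torch_dev.cudart()` behind the availability guards. Treating any nonzero return value as a failure is also an assumption.

```python
import ctypes
import logging
from multiprocessing import shared_memory

log = logging.getLogger("shm_pin_sketch")


class PinnedShmAttachment:
    """Hypothetical stand-in for the worker-side part of NonGpuContextShm."""

    def __init__(self, shm_name, pool_size, cudart=None):
        # `cudart` is injected here (assumption); None models an unsupported backend.
        self._cudart = cudart
        self._pool_size = pool_size
        self._pinned = False
        self._pinned_ptr = None
        self._shm = shared_memory.SharedMemory(name=shm_name.lstrip("/"), create=False)
        self._try_register()

    def _try_register(self):
        if self._cudart is None:
            log.warning("cudart unavailable; copies into SHM stay synchronous")
            return
        # Address of this process's own mapping of the segment.
        anchor = ctypes.c_char.from_buffer(self._shm.buf)
        ptr = ctypes.addressof(anchor)
        del anchor  # release the buffer export so close() can unmap later
        try:
            err = self._cudart.cudaHostRegister(ptr, self._pool_size, 0)
        except Exception as exc:
            log.warning("cudaHostRegister raised %r; staying synchronous", exc)
            return
        if int(err) != 0:  # assumption: 0 means success
            log.warning("cudaHostRegister returned %s; staying synchronous", err)
            return
        self._pinned, self._pinned_ptr = True, ptr

    def close(self):
        # Unregister before releasing the SHM segment, and only if we registered.
        if self._pinned:
            self._cudart.cudaHostUnregister(self._pinned_ptr)
            self._pinned = False
        self._shm.close()


class FakeCudart:
    """Test double that records calls; fail=True simulates a registration error."""

    def __init__(self, fail=False):
        self.fail = fail
        self.calls = []

    def cudaHostRegister(self, ptr, size, flags):
        self.calls.append("register")
        return 1 if self.fail else 0

    def cudaHostUnregister(self, ptr):
        self.calls.append("unregister")
        return 0


def run_once(fail):
    """Create a segment, attach to it as a worker would, then clean up."""
    owner = shared_memory.SharedMemory(create=True, size=4096)
    fake = FakeCudart(fail=fail)
    try:
        ctx = PinnedShmAttachment(owner.name, 4096, cudart=fake)
        pinned = ctx._pinned
        ctx.close()
    finally:
        owner.close()
        owner.unlink()
    return pinned, fake.calls
```

Calling `run_once(fail=False)` exercises the success path (register, then unregister on close), and `run_once(fail=True)` exercises the failure path, where a warning is logged and no unregister call is made. This mirrors the two test cases listed under "Focused coverage".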

Copilot AI changed the title from "[WIP] Fix non-blocking GPU memory copy in gather_paged_kv_to_cpu" to "Register worker SHM mappings for truly async D2H KV copies" on Jun 16, 2026
Copilot AI requested a review from hlin99 June 16, 2026 05:36