Async data path #368
Open
hlin99 wants to merge 28 commits into
Drop the dead-field cases (resumed_from_preemption / evicted_req_ids), which do not exist on vLLM main's CachedRequestData / SchedulerOutput, and keep only the real signals (resumed_req_ids, preempted_req_ids) plus the conservative unknown-schema fallback.
_scheduler_step_needs_flush previously probed two fields that do not exist on vLLM main's schema: CachedRequestData.resumed_from_preemption (replaced by resumed_req_ids) and SchedulerOutput.evicted_req_ids (never existed). Those getattr checks were dead code and the comment was inaccurate.

Verified against vLLM main (vllm/v1/core/sched/output.py):
- CachedRequestData.resumed_req_ids: set[str] -> real resume signal
- SchedulerOutput.preempted_req_ids: set[str] | None -> real preempt signal (populated unconditionally in scheduler.py)

Keep only those two real signals plus the conservative unknown-schema fallback (flush when scheduled_cached_reqs lacks resumed_req_ids). This matches the test cleanup in the previous commit; behavior on real vLLM is unchanged.
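The flush decision described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the function name `scheduler_step_needs_flush`, the `_MISSING` sentinel, and the `SimpleNamespace` stand-ins are assumptions made for the example, and treating a missing `scheduled_cached_reqs` as "unknown schema" is a guess at the conservative behavior. Only the attribute names `scheduled_cached_reqs`, `resumed_req_ids`, and `preempted_req_ids` come from the commit message.

```python
from types import SimpleNamespace

# Sentinel to tell "attribute absent" apart from "attribute present but empty".
_MISSING = object()


def scheduler_step_needs_flush(scheduler_output) -> bool:
    """Hypothetical sketch: return True when pending async stores should be flushed."""
    cached = getattr(scheduler_output, "scheduled_cached_reqs", None)
    resumed = getattr(cached, "resumed_req_ids", _MISSING)
    if resumed is _MISSING:
        # Unknown schema: be conservative and flush.
        return True
    if resumed:
        # At least one request resumed in this step.
        return True
    # preempted_req_ids may be None or an empty set; both mean "no preemption".
    preempted = getattr(scheduler_output, "preempted_req_ids", None)
    return bool(preempted)


# Stand-in objects (assumptions, not real vLLM types).
quiet_step = SimpleNamespace(
    scheduled_cached_reqs=SimpleNamespace(resumed_req_ids=set()),
    preempted_req_ids=None,
)
resumed_step = SimpleNamespace(
    scheduled_cached_reqs=SimpleNamespace(resumed_req_ids={"req-1"}),
    preempted_req_ids=set(),
)
```

With these stand-ins, `quiet_step` does not trigger a flush, while `resumed_step` does; an object whose `scheduled_cached_reqs` lacks `resumed_req_ids` also flushes via the fallback branch.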
…eMode error in commit thread

PyTorch's InferenceMode propagates to child threads. The commit thread inherits InferenceMode from the vLLM EngineCore main thread, causing `shm_view.copy_(staged)` to raise:

"Inplace update to inference tensor outside InferenceMode is not allowed"

Fix by explicitly exiting InferenceMode for the inplace copy operation.
…ant staging copy

When SHM out_buffers are available from prepare_store(), gather directly into them on the copy stream, matching the synchronous DataTransferContext behavior. This removes:
1. The redundant pinned staging buffer allocation for the SHM path
2. The staged→shm_view copy_ in the commit thread
3. The InferenceMode error caused by that copy_

Only the pickle path (no SHM) still uses pinned staging buffers.
Previously, submit_store performed the gather kernel launch (including _event.wait() and gather_paged_kv_to_cpu()) directly on the forward thread. When the copy stream has a pending event-wait (for the forward pass to finish), the CUDA runtime throttles the CPU as kernels queue up on a stream with unresolved dependencies, blocking the forward thread for ~38ms on every store.

This commit moves the entire gather phase into the background _commit_after_gather thread via the commit_executor. The forward thread now only does lightweight preparation (prepare_store, buffer allocation), then immediately submits the work and returns.

The background thread now:
1. Acquires the copy stream context
2. Inserts an event-level wait for forward completion
3. Launches gather_paged_kv_to_cpu()
4. Records the gather_done event on the copy stream
5. Adds gather_done to _inflight_gather_events (under lock)
6. Synchronizes gather_done (waits for the GPU gather to finish)
7. Calls commit_store() and resolves the future

Also removes profiling remnants: import time, the t00/t1/t2/t3/t4/t11 timing variables, the Store Profiler logger.info calls, and the two torch_dev.synchronize() calls that were added for profiling only.
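The hand-off described above can be illustrated with a small CPU-only sketch. It is not the PR's implementation: `AsyncStoreSketch` is a hypothetical class, `threading.Event` objects stand in for CUDA events, the gather kernel is simulated by setting an event, and removing the event from `_inflight_gather_events` after the wait is an assumption the commit message does not state. Only the names `submit_store`, `_commit_after_gather`, and `_inflight_gather_events` are taken from the text.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncStoreSketch:
    """Hypothetical stand-in: forward thread submits, background thread gathers and commits."""

    def __init__(self) -> None:
        self._commit_executor = ThreadPoolExecutor(max_workers=1)
        self._lock = threading.Lock()
        self._inflight_gather_events: list[threading.Event] = []
        self.committed: list[str] = []

    def submit_store(self, key: str, forward_done: threading.Event) -> Future:
        # Forward thread: lightweight preparation only, then return immediately.
        prepared = {"key": key}
        return self._commit_executor.submit(
            self._commit_after_gather, prepared, forward_done
        )

    def _commit_after_gather(self, prepared: dict, forward_done: threading.Event) -> str:
        # Stand-in for the event-level wait on forward completion.
        forward_done.wait()
        gather_done = threading.Event()
        with self._lock:
            self._inflight_gather_events.append(gather_done)
        # Stand-in for launching the gather and recording its completion event.
        gather_done.set()
        # Stand-in for synchronizing on gather_done.
        gather_done.wait()
        with self._lock:
            # Assumption: finished events are dropped from the in-flight list.
            self._inflight_gather_events.remove(gather_done)
        # Stand-in for commit_store(); the return value resolves the future.
        self.committed.append(prepared["key"])
        return prepared["key"]


store = AsyncStoreSketch()
forward_done = threading.Event()
future = store.submit_store("chunk-0", forward_done)  # returns without blocking
pending_before_forward = not future.done()
forward_done.set()  # the simulated forward pass finishes
result = future.result(timeout=5)
```

The point of the design is that every blocking wait lives in the worker thread, so the caller of `submit_store` only pays for preparation and the executor submission.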
Signed-off-by: Tony Lin <tony.lin@intel.com>
- worker_transfer.py: Add import time + timing to HandleTransferContext.submit_store() with a [FWD-IPC] log covering ipc_handle, send_request, to_cuda_future, and total ms
- gpu_transfer.py: Add granular timing to GPUTransferModule.store() with a [GPU-STORE] summary log and per-chunk [GPU-STORE-CHUNK] logs covering kernel launch and memcpy_d2h
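As a rough illustration of this kind of per-phase timing, the sketch below collects millisecond durations per named phase and formats them into one tagged log line. The helpers `timed` and `format_timing_log` and the exact log layout are assumptions for the example; the commit only names the tags and the phases it measures.

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger("transfer_timing_sketch")


@contextmanager
def timed(timings: dict, name: str):
    """Record the wall-clock duration of the enclosed block, in milliseconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = (time.perf_counter() - start) * 1000.0


def format_timing_log(tag: str, timings: dict) -> str:
    """Build a single summary line such as '[FWD-IPC] ipc_handle=0.120ms, ...'."""
    parts = ", ".join(f"{name}={ms:.3f}ms" for name, ms in timings.items())
    return f"[{tag}] {parts}"


timings: dict = {}
with timed(timings, "total"):
    with timed(timings, "ipc_handle"):
        pass  # placeholder for obtaining the IPC handle
    with timed(timings, "send_request"):
        pass  # placeholder for sending the request
logger.info(format_timing_log("FWD-IPC", timings))
```

Note that timing GPU work this way measures only the CPU-side launch cost unless the device is synchronized first, which is one reason such instrumentation tends to be temporary.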
…ed MP transfer primitive (LMCache#3508) Signed-off-by: Tony Lin <tony.lin@intel.com>
…notable speedup (LMCache#3591)

* Perf: optimize Python fallback block transfer for 3x speedup

  - Optimize fallback block-id and D2H staging overhead
  - Restructure per-layer transfer loops to iterate over objects first then layers

  Signed-off-by: Tony Lin <tony.lin@intel.com>

* apply gemini's suggestion

  Signed-off-by: Tony Lin <tony.lin@intel.com>

* optimize flash_infer block transfer paths in python fallback

  Signed-off-by: Tony Lin <tony.lin@intel.com>

---------

Signed-off-by: Tony Lin <tony.lin@intel.com>
What this PR does / why we need it:
Special notes for your reviewers:
If applicable: