feat(kvstore): support mamba l2 cache transfers #162

Open: XucSh wants to merge 4 commits into main from Xuchun/mamba-l2

@XucSh XucSh commented May 15, 2026

Summary

This PR adds L2 cache support for Mamba cache alongside the existing KV cache L2 path.

The core change is to generalize cache movement into CacheTransferUnit, which carries the cache kind (KV or Mamba) plus source and destination slots/pages. WriteBackOperation and LoadBackOperation can now carry mixed KV/Mamba transfer units without adding Mamba-specific scheduler branches everywhere.
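As a rough illustration of the idea (a Python sketch, not the C++ code in this PR), a transfer unit can be modeled as a small record tagged with its cache kind, and an operation as a list of such records. The field names `src_index` and `dst_index`, the `run_transfers` helper, and the `copiers` table are hypothetical and exist only for this example.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Dict, List


class CacheTransferKind(Enum):
    KV = "kv"
    MAMBA = "mamba"


@dataclass(frozen=True)
class CacheTransferUnit:
    kind: CacheTransferKind
    src_index: int  # source page (KV) or slot (Mamba); hypothetical field name
    dst_index: int  # destination page or slot; hypothetical field name


@dataclass
class WriteBackOperation:
    # One operation may carry a mix of KV and Mamba units.
    units: List[CacheTransferUnit] = field(default_factory=list)

    def by_kind(self, kind: CacheTransferKind) -> List[CacheTransferUnit]:
        return [u for u in self.units if u.kind is kind]


def run_transfers(
    units: List[CacheTransferUnit],
    copiers: Dict[CacheTransferKind, Callable[[int, int], None]],
) -> None:
    """Dispatch each unit to the copy routine registered for its kind."""
    for unit in units:
        copiers[unit.kind](unit.src_index, unit.dst_index)
```

Because dispatch is keyed on `kind`, the caller does not need a separate Mamba branch; supporting another cache kind in this sketch would mean adding one enum value and one entry in `copiers`.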

Key behavior:

  • Mamba cache can be written back from GPU to host memory.
  • Mamba cache can be loaded back from host memory to GPU.
  • Mamba host memory supports eviction when host Mamba slots are exhausted.
  • Mamba loadback completion is synchronized through stream fence semantics, not scheduler loadback-done events.
  • Prefix reuse now requires all cache types for a token prefix to exist. KV-only cache is not considered reusable when the corresponding Mamba state is missing.
  • Mamba state depth is tracked explicitly so page-aligned prefix states and non-exact retraction recovery states are not confused.
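To make the eviction and reuse rules concrete, here is a minimal Python sketch (not the PR's implementation). It assumes a least-recently-used policy, which the description above does not specify; `HostMambaPool`, `is_reusable`, and the string node ids are hypothetical names for illustration.

```python
from collections import OrderedDict
from typing import Optional, Set


class HostMambaPool:
    """Fixed set of host slots; evicts the oldest owner when exhausted (assumed LRU)."""

    def __init__(self, num_slots: int) -> None:
        self.free = list(range(num_slots))
        self.owners: "OrderedDict[str, int]" = OrderedDict()  # node id -> slot, oldest first

    def allocate(self, node_id: str) -> Optional[int]:
        if not self.free:
            if not self.owners:
                return None  # pool has no slots at all
            # Evict the least recently used owner and recycle its slot.
            _, slot = self.owners.popitem(last=False)
            self.free.append(slot)
        slot = self.free.pop()
        self.owners[node_id] = slot
        return slot

    def touch(self, node_id: str) -> None:
        # Mark a node's state as recently used.
        self.owners.move_to_end(node_id)

    def has_state(self, node_id: str) -> bool:
        return node_id in self.owners


def is_reusable(node_id: str, kv_present: Set[str], pool: HostMambaPool) -> bool:
    """A prefix is reusable only if both its KV pages and its Mamba state exist."""
    return node_id in kv_present and pool.has_state(node_id)
```

In this sketch, a node that loses its host Mamba slot stops being reusable even though its KV entry is still present, which mirrors the "all cache types must exist" rule.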

Important design points:

  • CacheTransferUnit avoids duplicating writeback/loadback logic for KV and Mamba.
  • TreeNode tracks device/host Mamba slots and their token depth.
  • FindLastMambaNode(..., require_exact_depth=true) is used for prefix cache reuse.
  • Retraction recovery may use non-exact live Mamba state, but normal prefix reuse may not.
  • Host-side Mamba eviction detaches only the host Mamba state; future prefix reuse still requires KV and Mamba to both be present.
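The difference between exact and non-exact depth lookup can be sketched as follows. This is a simplified Python model, not the C++ `FindLastMambaNode` from the PR: the real function's other arguments are not shown here, and the `path` parameter and both `TreeNode` fields are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TreeNode:
    depth_tokens: int                         # tokens covered by the prefix ending at this node
    mamba_depth_tokens: Optional[int] = None  # token depth of the attached Mamba state, if any


def find_last_mamba_node(
    path: List[TreeNode], require_exact_depth: bool
) -> Optional[TreeNode]:
    """Return the deepest node on a matched path that holds usable Mamba state."""
    for node in reversed(path):
        if node.mamba_depth_tokens is None:
            continue  # no Mamba state attached to this node
        if require_exact_depth and node.mamba_depth_tokens != node.depth_tokens:
            continue  # state exists but does not sit exactly at the node boundary
        return node
    return None
```

With `require_exact_depth=True` the lookup skips a node whose state was captured at a different token depth than the node boundary, so prefix reuse only ever sees boundary-exact states. With `require_exact_depth=False` the same node is returned, which is the looser behavior described for retraction recovery.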

Test Plan

CUDA_VISIBLE_DEVICES=4,5,6,7 \
  ts serve Qwen/Qwen3.5-122B-A10B \
  --tp 4 \
  --max-num-seqs 16 \
  --max-total-tokens 160000 \
  --max-model-len 80000 \
  --chunked-prefill-size 128 \
  --max-prefill-tokens 128 \
  --block-size 64 \
  --max-mamba-cache-size 64 \
  --mamba-track-interval 64 \
  --kvstore-ratio 1.0 \
  --kvstore-io-backend kernel \
  --disable-prefill-graph \
  --enforce-eager \
  --disable-overlap-schedule \
  --enable-cache-report \
  --host 127.0.0.1 \
  --port 8000

evalscope eval \
  --model Qwen/Qwen3.5-122B-A10B \
  --api-url http://127.0.0.1:8000/v1 \
  --api-key EMPTY_TOKEN \
  --datasets aime25 \
  --eval-batch-size 16 \
  --generation-config '{"max_tokens":100000}'

Mamba loadback and writeback can be observed during the run. The result:

image

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh requested a review from a team as a code owner May 15, 2026 10:20
@XucSh XucSh marked this pull request as draft May 15, 2026 10:36
XucSh added 2 commits May 15, 2026 15:29
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh marked this pull request as ready for review May 16, 2026 01:15
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh requested review from SimonCqk and tuanzhangCS May 18, 2026 04:49
}

return Draining{BuildWriteBackPairs(write_diff), std::move(device_node_ref), std::move(host_node_ref)};
if (need_mamba_writeback) {

Currently, we only call write_back when a request finishes or is retracted, and we only write back the last Mamba slot of that request. Should we also drain Mamba slots during the prefill stage to improve the cache hit rate?

}
hybrid_prefix_cache_->AttachHostMamba(terminal, std::move(host_slot), terminal->MambaDeviceDepthTokens());
pages_to_transfer.push_back(
{CacheTransferKind::Mamba, terminal->MambaSlotIndex(), terminal->MambaHostSlotIndex()});

Will the mamba_host_slot associated with a retracted request be protected from eviction?


def write(self, executor, op: _TransferOp, prepared) -> None:
    executor._copy_mamba_slots(
        executor.mamba_pool,

Quick question: KVTransferBackend has draft_pool write/load operations. Could Mamba also be used with a draft pool?

The current draft version of Qwen3.5 does not include Mamba layers. However, I cannot confirm whether a Mamba pool will be introduced in future versions or in other models.

auto slot = hybrid_prefix_cache_->AllocateDeviceMamba();
if (slot == nullptr) return {};
hybrid_prefix_cache_->LoadBackMamba(node, std::move(slot));
match_result.mamba_cow_src_index = node->MambaSlotIndex();

mamba_loadback_diff is a vector, and for now its length is always 1, but overwriting mamba_cow_src_index inside the for-loop seems odd. Should it be lifted to a vector too?

conv_dtype=self.conv_dtype,
ssm_dtype=self.ssm_dtype,
mamba_layer_ids=self.mamba_layer_ids,
device="cpu",

Should the memory of the Mamba host pool be pinned?

@SimonCqk

LGTM overall. Since #146 was just merged, there are conflicts in some files; please resolve them before merging 🚀🙌🏻


4 participants