From 56e78c6a7b1a3334d6f47aebb0df435223044aa3 Mon Sep 17 00:00:00 2001 From: youngrok-XCENA Date: Tue, 2 Jun 2026 19:14:52 +0900 Subject: [PATCH] fix: support latest vllm kv cache layout --- maru_vllm/connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/maru_vllm/connector.py b/maru_vllm/connector.py index 86375fe..a30f363 100644 --- a/maru_vllm/connector.py +++ b/maru_vllm/connector.py @@ -685,7 +685,11 @@ def start_load_kv( if kv_cache_attr is None: continue - kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] + if isinstance(kv_cache_attr, (list, tuple)): + virtual_engine = getattr(forward_context, "virtual_engine", 0) + kv_cache_layer = kv_cache_attr[virtual_engine] + else: + kv_cache_layer = kv_cache_attr # Use the same layer index derivation as save path # to ensure key consistency (enumerate index may diverge # when no_compile_layers contains non-attention layers).