Skip to content

feat/maru backend#29

Merged
jooho-XCENA merged 11 commits into
mainfrom
feat/maru_backend
Apr 3, 2026
Merged

feat/maru backend#29
jooho-XCENA merged 11 commits into
mainfrom
feat/maru_backend

Conversation

@hyunyul-XCENA

Copy link
Copy Markdown
Collaborator

Summary

Add LMCache MaruBackend support with CXL zero-copy memory adapter, fixed pool allocation with auto-expand, and server-side pin/unpin RPC for eviction protection. This consolidates the Maru-LMCache integration from the old connector-based approach to a direct AllocatorBackendInterface implementation backed by CXL shared memory.

Key Changes

MaruBackend & CxlMemoryAdapter (#25)

  • Replace MaruConnector/MaruConnectorAdapter with CxlMemoryAdapter implementing LMCache's MemoryAllocatorInterface
  • Add MaruHandler APIs: pin, unpin, batch_pin, batch_unpin, batch_store, batch_retrieve
  • Add single/disagg/p2p example scripts and configs
  • Delete old maru_lmcache/connector.py (642 lines removed)

Fixed Pool Allocation (#26)

  • Add OwnedRegionManager.fixed_alloc() for deterministic page-level allocation
  • Support optional auto-expand when owned regions are exhausted
  • Add set_on_region_added callback for dynamic region tracking

Server-side Pin/Unpin RPC (#27)

  • Add EXISTS_AND_PIN_KV (0x14), UNPIN_KV (0x15), BATCH_EXISTS_AND_PIN_KV (0x23), BATCH_UNPIN_KV (0x24) protocol messages
  • Add pin_count field on KVEntry with prefix-stop batch pin semantics
  • Pinned entries are protected from deletion (delete() refuses when pin_count > 0)

Test Plan

  • Unit tests added/updated (test_cxl_memory_adapter, test_kv_manager, test_maru_handler, test_maru_backend)
  • Existing tests pass (pytest -v)
  • Integration tests updated (test_handler, test_maru_integration)
  • E2E validated with P2P sharing example (BATCH_EXISTS_AND_PIN / BATCH_UNPIN flow)

Related Issues

youngrok-XCENA and others added 3 commits March 17, 2026 11:49
* feat: maru storage backend bring-up

* refactor: rename CxlMemoryAllocator to CxlMemoryAdapter with facade pattern

- Add MaruHandler facade API (get_buffer_view, get_region_page_count,
  get_owned_region_ids, get_chunk_size) to eliminate Law of Demeter violations
- Add set_on_region_added callback with replay for region expansion support
- Rename CxlMemoryAllocator → CxlMemoryAdapter to reflect adapter role
- Unify pool build path: callback handles both init (replay) and expansion
- Remove deprecated connector/adapter tests (storage backend replaces connector)
- Fix _build_region_pool to abort on buffer failure instead of skipping pages

* fix: address medium/low review feedback

allocator.py:
- Implement free()/batched_free() to return handler pages (was no-op)
- Use fmt.token_dim() instead of hardcoded shape index

connector.py (deprecated):
- Fix handle leak in _batch_store() on partial alloc failure
- Change zip strict=False to strict=True

handler.py:
- Fix docstring: key=12345 → key="12345"

* test: add async lookup API tests for MaruBackend

Cover batched_async_contains and batched_get_non_blocking with 7 tests
mirroring the connector-era TestBatchOperations coverage: all-hit,
partial prefix, first miss, empty keys, and prefix stop-on-miss.

* refactor: update example configs and scripts for MaruBackend

- Migrate configs from connector (remote_url/remote_storage_plugins)
  to native MaruBackend (maru_path/maru_pool_size)
- Improve disagg_example_1p1d.sh process management: setsid for
  process groups, stale PID cleanup, sequential instance launch
- Enable vLLM request logging for debugging
- Fix p2p_example.sh launch order: wait for inst1 before starting inst2

* refactor: remove deprecated MaruConnector and MaruConnectorAdapter

Storage backend (MaruBackend) replaced the connector approach.
These modules are no longer imported by LMCache or any tests.
Docs references will be cleaned up separately.

* fix: handle partial chunk in CxlMemoryAdapter.allocate()

allocate() returned pre-built pool objects with full chunk_size shape
even for partial chunks. When the CUDA kernel in multi_layer_kv_transfer
reads num_tokens from key_value.size(2), it launched more blocks than
slot_mapping entries, causing GPU OOB read on slot_mapping.

Call _create_partial_view() in allocate() when requested size < chunk_size
so that memory_obj.tensor shape matches slot_mapping size. Also enable
save_unfull_chunk in PD example configs.

* refactor: rename allocator.py to adapter.py to match CxlMemoryAdapter class name

* test: initialize PinMonitor in MaruBackend test fixture

get_blocking calls memory_obj.pin() which requires PinMonitor singleton.
Add autouse fixture to initialize/teardown PinMonitor per test.
* feat: add fixed pool allocation with optional auto-expand

- Add auto_expand (default False) and expand_size config options
- Change connect() to aggregate regions across multiple pools
- Gate _expand_region() on auto_expand flag, use expand_size
- Improve alloc() error messages (disabled vs failed expansion)
- Add 13 new tests for config validation and behavior

* style: fix ruff format for test files

* test: skip cxl_memory_adapter tests when torch/lmcache not installed

* fix: default auto_expand to True and fix assertion operator precedence bug
* feat: add EXISTS_AND_PIN_KV and UNPIN_KV RPC for server-side eviction protection

- Add EXISTS_AND_PIN_KV (0x14) and UNPIN_KV (0x15) message types to protocol
- Add pin_count field to KVEntry for tracking pinned entries
- Add exists_and_pin() and unpin() methods to KVManager
- Add exists_and_pin_kv() and unpin_kv() to MaruServer
- Add RPC handlers for new message types
- Add client methods to MaruHandler, RpcClient, and RpcAsyncClient

* feat: add batch pin/unpin RPC operations

- Add BATCH_EXISTS_AND_PIN_KV (0x23), BATCH_UNPIN_KV (0x24), BATCH_PIN_KV (0x25) message types
- Add batch_exists_and_pin(), batch_pin(), batch_unpin() to KVManager
- Add corresponding MaruServer, RPC handler, and client methods
- Enables single-RPC batch operations instead of N individual calls

* test: update tests for batch RPC and ref_count changes

- test_ref_count_managed_during_put: expect initial_ref + 1 (pool ref retained)
- batched_async_contains tests: mock batch_exists instead of individual exists

* fix: remove pre-commit config, resolve ruff lint/format errors, skip torch-dependent tests in CI

* refactor: clean up pin/unpin RPC and add server-side pin timeout monitor

- Remove redundant PIN_KV (0x16), EXISTS_AND_PIN_KV (0x14), BATCH_PIN_KV
  (0x25) — only BATCH_EXISTS_AND_PIN_KV, UNPIN_KV, BATCH_UNPIN_KV remain
- Renumber UNPIN_KV to 0x14 (fill gap from removed ops)
- Add pin timeout monitor to KVManager/MaruServer (daemon thread, 30s
  interval, 60s timeout) to force-unpin leaked entries
- Refuse delete on pinned entries (pin_count > 0) with warning log
- Make batch_exists_and_pin prefix-aware: stop at first miss to prevent
  pin leaks on non-prefix keys
- Propagate exceptions from batch_exists_and_pin/batch_unpin_kv to caller
  instead of silently returning [False]*N
- Add hit/ok counts to batch RPC handler logs
- Update tests to match LMCache API changes (batch_exists, batch_retrieve,
  single-future batched_submit_put_task)

* feat: restore single EXISTS_AND_PIN_KV RPC and remove PinMonitor

- Restore EXISTS_AND_PIN_KV (0x14) single-key RPC across full stack
  (protocol, kv_manager, server, rpc_handler, rpc_client, handler)
- Renumber UNPIN_KV to 0x15
- Remove PinMonitor daemon thread and pin_timestamps tracking
  (no eviction yet, so pin leaks have no practical impact)
- Add TODO comments for future PinMonitor when eviction is implemented

* fix: add pinned count to BATCH_EXISTS_AND_PIN log message

* fix: restore .pre-commit-config.yaml

* fix: update ref_count test to match single ref_count_up in submit_put_task

* fix: address PR #27 review feedback

- delete() returns DeleteResult enum (NOT_FOUND/PINNED/DELETED) instead
  of ambiguous (bool, None) tuple
- Rename exists_and_pin -> pin, batch_exists_and_pin -> batch_pin across
  protocol, server, handler, and RPC clients
- Rename unpin_kv -> unpin, batch_unpin_kv -> batch_unpin for naming
  consistency in handler public API
- Wrap batch_pin/batch_unpin RPC returns in BatchPinKVResponse/
  BatchUnpinKVResponse for return type consistency
- Fix redundant log parameters in BATCH_PIN handler
- Add docstring clarification for batch_exists vs batch_pin semantics
- Add unit tests for pin/unpin/batch_pin/batch_unpin/delete-pinned
- Export Pin/Unpin protocol types from maru_common

* style: ruff format
@github-actions

github-actions Bot commented Mar 20, 2026

Copy link
Copy Markdown

@hyunyul-XCENA hyunyul-XCENA changed the title Feat/maru backend feat/maru backend Mar 20, 2026
seohui-XCENA and others added 5 commits March 20, 2026 08:31
Reflect the migration from RemoteConnector to AllocatorBackendInterface:
- Replace MaruConnector references with MaruBackend + CxlMemoryAdapter
- Update data path diagrams (store/retrieve/pin-unpin)
- Simplify config section (maru_path + maru_pool_size)
- Update p2p and pd example configs
- Update component architecture diagram
Align with LMCache config change where maru_pool_size is now
float (GB) instead of str (e.g. "4G"). Update all examples,
docs, and tests accordingly.

@youngrok-XCENA youngrok-XCENA left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 리뷰: feat/maru backend

이 PR이 필요한 이유 (미적용시 문제점)

  1. Zero-copy 미활용: 기존 MaruConnector 기반 통합은 LMCache의 RemoteConnector 인터페이스에 의존하여 데이터를 TCP 직렬화/역직렬화합니다. CXL 공유 메모리의 zero-copy 이점을 전혀 활용하지 못하며, store/retrieve마다 불필요한 데이터 복사가 발생합니다.
  2. P2P/PD 데이터 무결성 위험: Pin/Unpin 메커니즘이 없으면, disaggregated prefill이나 P2P 공유 시나리오에서 한 인스턴스가 KV 캐시 항목을 읽는 동안 다른 인스턴스에서 해당 항목을 삭제/회수할 수 있어 use-after-free 데이터 손상이 발생합니다.
  3. Pool 소진 시 OOM: auto_expand가 없으면 초기 할당된 메모리 풀이 소진되었을 때 추가 할당이 불가능하여 서비스가 중단됩니다.

설계 변경점

Before: Connector 기반 (간접 접근, 데이터 복사 발생)

graph LR
    A["LMCache Engine"] --> B["RemoteConnector<br/>(plugin interface)"]
    B --> C["MaruConnectorAdapter"]
    C --> D["MaruConnector<br/>(TCP serialize/deserialize)"]
    D --> E["MaruHandler"]
    E --> F["MaruServer + KVManager"]
    style D fill:#f99,stroke:#333
Loading

After: Backend 직접 통합 (zero-copy, pin/unpin 지원)

graph LR
    A["LMCache Engine"] --> B["MaruBackend<br/>(AllocatorBackendInterface)"]
    B --> C["CxlMemoryAdapter<br/>(MemoryAllocatorInterface)"]
    C --> D["MaruHandler"]
    B -.->|"store/retrieve<br/>pin/unpin"| D
    D --> E["OwnedRegionManager<br/>+auto_expand<br/>+on_region_added callback"]
    D --> F["MaruServer"]
    F --> G["KVManager<br/>+pin_count<br/>+DeleteResult enum<br/>+prefix-stop batch_pin"]
    style C fill:#9f9,stroke:#333
    style G fill:#9f9,stroke:#333
Loading

핵심 아키텍처 변경:

  • MaruConnector(TCP 직렬화) 계층 제거 → CxlMemoryAdapter가 CXL 공유 메모리에 직접 접근
  • Address encoding: (region_id << 32) | page_index — O(1) 양방향 변환, 누적 오프셋 테이블 불필요
  • Pool 구축이 on_region_added 콜백 기반: 초기화(replay)와 확장(expansion) 모두 단일 경로로 처리
  • Server-side pin/unpin: prefix-stop 시맨틱스로 pin leak 방지, DeleteResult enum으로 삭제 거부 사유 구분

주요 변경 요약

영역 변경 주요 파일
LMCache 통합 MaruConnector → CxlMemoryAdapter + MaruBackend adapter.py (신규), connector.py (삭제)
프로토콜 PIN_KV(0x14), UNPIN_KV(0x15), BATCH_PIN(0x23), BATCH_UNPIN(0x24) protocol.py
서버 pin_count, DeleteResult enum, prefix-stop batch_pin kv_manager.py, server.py
핸들러 pin/unpin/batch API, free() 구현, auto_expand handler.py
설정 auto_expand, expand_size 옵션 config.py
테스트 93% 커버리지, 587 테스트 전체 통과 다수 테스트 파일

기존 코드 관련 참고사항

handler.pyfree() 메서드(line 437-443)에서 _key_to_location을 전체 순회(O(n))하여 (region_id, page_index)로 key를 역탐색합니다. 이번 PR의 변경 범위는 아니지만, KV 엔트리가 많아지면 성능 병목이 될 수 있으므로 후속 작업으로 _location_to_key reverse mapping 추가를 고려해 주세요.

총평

Connector → Backend 전환은 CXL zero-copy를 제대로 활용하는 올바른 방향이며, pin/unpin의 prefix-stop 시맨틱스, DeleteResult enum 도입은 견고한 설계입니다. 테스트 커버리지 93%도 우수합니다.

아래 인라인 코멘트에 코드 레벨 개선사항을 남겼으니 확인 부탁드립니다.

Comment thread maru_lmcache/adapter.py
MaruBackend.remove() -> handler.delete().
"""
rid, pid = self.decode_address(memory_obj.metadata.address)
handle = AllocHandle(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] free()에서 AllocHandle(buf=memoryview(b""), _size=0)으로 dummy handle을 생성하여 handler.free()를 호출하고 있습니다. handler.free()는 handle의 _region_id/_page_index만 사용하므로 동작에는 문제가 없지만, 빈 memoryview를 가진 handle 객체가 의미적으로 어색합니다.

MaruHandlerfree_by_location(region_id, page_index) 같은 direct method를 추가하면 이 우회 패턴을 제거하고 의도를 더 명확히 할 수 있습니다.

Comment thread maru_server/server.py
self._allocation_manager.decrement_kv_ref(region_to_deref)

return existed
return result == DeleteResult.DELETED

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] KVManager.delete()DeleteResult(NOT_FOUND/PINNED/DELETED)를 반환하도록 잘 개선했지만, 여기서 result == DeleteResult.DELETED로 bool 변환하면서 PINNED와 NOT_FOUND가 구분되지 않습니다.

DeleteKVResponse에도 success: bool만 있어서, 클라이언트가 삭제 실패 원인(핀 vs 부재)을 알 수 없습니다. Eviction 구현 시 "unpin 후 재시도" vs "키 없으니 무시"를 판단할 수 있도록, response에 result 필드를 전달하는 것을 고려해 주세요.

def batch_unpin(self, keys: list[str]) -> BatchUnpinKVResponse:
"""Unpin multiple KV entries in a single RPC call."""
response = self._send_request(MessageType.BATCH_UNPIN_KV, {"keys": keys})
return BatchUnpinKVResponse(results=response.get("results", []))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] RpcAsyncClientpin_kv_async, unpin_async, batch_pin_kv_async, batch_unpin_async non-blocking 메서드가 누락되어 있습니다.

다른 KV 연산(exists_kv_async, delete_kv_async, batch_register_kv_async 등)은 모두 *_async 변형이 있으므로, pin/unpin도 동일하게 추가하면 파이프라이닝이 가능합니다.

Comment thread maru_server/kv_manager.py
results = []
for key in keys:
entry = self._store.get(key)
if entry is None or entry.pin_count <= 0:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[LOW] unpin()(line 113-117)에서는 key 미존재 또는 pin_count <= 0일 때 logger.warning()으로 로그를 남기는데, batch_unpin()에서는 동일한 상황에서 로그 없이 False만 반환합니다. 디버깅 편의를 위해 일관된 로그 정책을 권장합니다.

Comment thread maru_lmcache/adapter.py
)
return

flat_dtype = self._dtypes[0]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[LOW] flat_dtype = self._dtypes[0]으로 항상 첫 번째 dtype만 사용합니다. 현재는 모든 레이어가 동일 dtype이므로 문제없지만, 향후 mixed-precision KV cache(key=fp16, value=fp8) 지원 시 이 가정이 깨질 수 있습니다.

방어적 assertion을 추가하는 것을 권장합니다:

assert all(d == self._dtypes[0] for d in self._dtypes), "mixed dtypes not yet supported"

@youngrok-XCENA youngrok-XCENA left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 리뷰: feat/maru backend

이 PR이 필요한 이유 (미적용시 문제점)

  1. Zero-copy 미활용: 기존 MaruConnector는 TCP 직렬화/역직렬화를 통해 LMCache와 통신하여, CXL 공유 메모리의 zero-copy 이점을 전혀 활용하지 못합니다.
  2. P2P/PD 데이터 무결성 위험: Pin/Unpin 메커니즘 없이 disaggregated prefill이나 P2P 공유 시 use-after-free 데이터 손상이 발생합니다.
  3. Pool 소진 시 OOM: auto_expand가 없으면 초기 메모리 풀 소진 시 서비스가 중단됩니다.

설계 변경점

Before: Connector 기반 (간접 접근, TCP 직렬화)

graph LR
    A[LMCache Engine] --> B[RemoteConnector]
    B --> C[MaruConnectorAdapter]
    C --> D[MaruConnector - TCP serialize]
    D --> E[MaruHandler]
    E --> F[MaruServer + KVManager]
    style D fill:#f99,stroke:#333
Loading

After: Backend 직접 통합 (zero-copy, pin/unpin 지원)

graph LR
    A[LMCache Engine] --> B[MaruBackend - AllocatorBackendInterface]
    B --> C[CxlMemoryAdapter - MemoryAllocatorInterface]
    C --> D[MaruHandler]
    B -.-> |store/retrieve, pin/unpin| D
    D --> E[OwnedRegionManager +auto_expand +on_region_added]
    D --> F[MaruServer]
    F --> G[KVManager +pin_count +DeleteResult +prefix-stop batch_pin]
    style C fill:#9f9,stroke:#333
    style G fill:#9f9,stroke:#333
Loading

핵심 아키텍처 변경:

  • MaruConnector(TCP 직렬화) 계층 제거 -> CxlMemoryAdapter가 CXL 공유 메모리에 직접 접근
  • Address encoding: (region_id << 32) | page_index -- O(1) 양방향 변환
  • Pool 구축이 on_region_added 콜백 기반: 초기화(replay)와 확장(expansion) 단일 경로
  • Server-side pin/unpin: prefix-stop 시맨틱스, DeleteResult enum으로 삭제 거부 사유 구분

리뷰 결과 요약

전반적으로 Connector -> Backend 전환은 올바른 방향이며, 93% 테스트 커버리지도 우수합니다.

아래 인라인 코멘트에 @youngrok-XCENA 님의 기존 리뷰에서 다루지 않은 추가 이슈들을 남겼습니다:

심각도 파일 이슈
HIGH handler.py batch_store 개별 등록 실패 시 페이지 릭
MEDIUM handler.py set_on_region_added 콜백 replay 시 thread-safety
MEDIUM kv_manager.py pin_count 상한 없음 + pin 메트릭 부재
MEDIUM rpc_client.py unpin/batch_unpin 네이밍 불일치 (_kv suffix 누락)
MEDIUM owned_region_manager.py Query 메서드에 lock 미적용

참고: PR description의 메시지 타입 이름(EXISTS_AND_PIN_KV)과 실제 코드(PIN_KV)가 다릅니다.

Comment thread maru_handler/handler.py
@@ -912,15 +875,7 @@ def batch_store(
if results[i] and i in allocations:
self._key_to_location[key] = allocations[i]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[HIGH] batch_store 개별 등록 실패 시 페이지 릭

batch_resp.results[batch_idx]False인 경우(서버가 개별 키 등록을 거부), results[i]False로 설정되지만 allocations[i]에 할당된 페이지가 해제되지 않습니다.

전체 배치 실패(exception 또는 batch_resp.success == False)만 페이지를 해제하고, 개별 실패는 누락됩니다.

장기 실행 프로세스에서 concurrent writer들의 race-lost 등록이 CXL 페이지를 영구적으로 누출하여 풀을 소진시킬 수 있습니다.

제안: batch result mapping 루프 이후에 개별 실패 항목의 페이지 해제 추가:

for i, (rid, pidx) in allocations.items():
    if not results[i]:
        self._owned.free(rid, pidx)

Comment thread maru_handler/handler.py
Args:
callback: Called with (region_id, page_count), or None to unregister.
"""
self._on_region_added = callback

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] set_on_region_added 콜백 설정/replay 시 thread-safety 미보장

self._on_region_added를 동기화 없이 설정하고, self._owned 리전을 _write_lock 없이 순회합니다.

_expand_region()_write_lock 하에서 리전을 추가하면서 콜백을 호출하는데, set_on_region_added가 동시에 실행되면 replay 루프와 expansion 콜백이 race하여 중복 또는 누락된 리전 알림이 발생할 수 있습니다.

제안: _write_lock을 메서드 전체에 적용:

def set_on_region_added(self, callback):
    with self._write_lock:
        self._on_region_added = callback
        if callback is not None and self._owned is not None:
            for rid in self._owned.get_region_ids():
                ...

Comment thread maru_server/kv_manager.py
entry = self._store.get(key)
if entry is None:
return False
entry.pin_count += 1

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] pin_count 상한 없음 -- 무한 증가 가능 + pin 메트릭 부재

pin()batch_pin() 모두 entry.pin_count += 1을 상한 체크 없이 수행합니다. buggy client가 pin_count를 무한 증가시킬 수 있습니다.

server.py:client_disconnected()는 할당만 해제하고 pin은 해제하지 않으므로, disconnect 후에도 높아진 pin_count가 유지되어 엔트리가 영구 삭제 불가능해집니다.

또한 get_stats()에 pin 관련 메트릭이 없어 운영 환경에서 pin leak을 탐지할 수 없습니다.

제안 1 -- 상한 추가:

MAX_PIN_COUNT = 256
if entry.pin_count >= MAX_PIN_COUNT:
    logger.warning("Pin refused: key=%s max=%d", key, MAX_PIN_COUNT)
    return False

제안 2 -- get_stats()에 pin 메트릭 추가:

pinned = [e for e in self._store.values() if e.pin_count > 0]
stats["pinned_entries"] = len(pinned)
stats["total_pin_count"] = sum(e.pin_count for e in pinned)

"""
response = self._send_request(MessageType.PIN_KV, {"key": key})
return response.get("exists", False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] unpin/batch_unpin 네이밍 불일치 -- _kv suffix 누락

모든 KV 연산이 _kv suffix를 사용합니다: exists_kv, delete_kv, register_kv, lookup_kv, pin_kv, batch_pin_kv.
그러나 unpin()batch_unpin()_kv suffix가 없습니다.

이 불일치가 RpcClient, RpcAsyncClient, MaruServer, KVManager 4개 클래스에 걸쳐 있으며, _kv 패턴으로 API를 탐색하는 개발자가 unpin 메서드를 누락할 수 있습니다. 나중에 rename하면 breaking change가 됩니다.

제안: 지금 unpin_kv(), batch_unpin_kv()로 통일.

return self._chunk_size

def get_region_ids(self) -> list[int]:
"""Get list of owned region IDs in insertion order."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] Query 메서드(is_owned, is_full, get_stats, get_region_ids)에 _lock 미적용

add_region(), close(), allocate(), free()self._lock으로 보호되지만, is_owned()(line 196), is_full(line 200), get_stats()(line 220), get_region_ids()(여기) 등은 _regions를 lock 없이 읽습니다.

is_full은 multi-step 읽기를 수행하고, get_stats()_region_order 순회 중 _regions를 indexing합니다. close()가 동시에 _regions를 clear하면 KeyError 또는 잘못된 결과가 발생할 수 있습니다.

특히 is_owned()handler.pyretrieve/batch_retrieve 읽기 경로에서 _write_lock 없이 호출됩니다.

제안: Query 메서드에도 self._lock 적용:

def is_owned(self, region_id: int) -> bool:
    with self._lock:
        return region_id in self._regions

def get_region_ids(self) -> list[int]:
    with self._lock:
        return list(self._region_order)

@youngrok-XCENA youngrok-XCENA left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 29 리뷰: feat/maru backend (CXL Zero-Copy 아키텍처)

이 PR이 필요한 이유

기존 MaruConnector는 TCP 소켓을 통해 KV 캐시 데이터를 직렬화/역직렬화하며 복사했습니다. GPU 메모리와 CXL 공유 메모리 간 데이터 이동에서 불필요한 CPU 복사가 병목이 되어, 대규모 LLM 서빙 시 레이턴시와 처리량이 저하됩니다.

이 PR은 MaruBackend + CxlMemoryAdapter 구조로 전환하여:

  • CXL 공유 메모리에 직접 텐서를 매핑 (zero-copy)
  • Pin/Unpin RPC로 eviction 보호
  • 배치 연산으로 RPC 왕복 최소화

설계 변경점

flowchart TD
    subgraph Before
        A[LMCache Engine] --> B[MaruConnector]
        B -->|TCP serialize/deserialize| C[MaruServer]
        C --> D[CXL Memory]
    end

    subgraph After
        E[LMCache Engine] --> F[MaruBackend]
        F --> G[CxlMemoryAdapter]
        G -->|MemoryAllocatorInterface| H[MaruHandler]
        H -->|zero-copy mmap| I[CXL Memory]
        H -->|RPC pin/unpin/batch| J[MaruServer]
        J --> K[KVManager with pin_count]
        J --> L[AllocationManager]
    end
Loading

batch_store 에러 처리 흐름:

flowchart TD
    S[batch_store 시작] --> A1[batch_pin RPC 호출]
    A1 --> A2[prefix-stop 결과 수신]
    A2 --> C1[히트된 청크: 기존 캐시 반환]
    A2 --> C2[미스된 청크: 페이지 할당]
    C2 --> D1[GPU to CXL 복사]
    D1 --> D2[batch_register RPC]
    D2 --> E1[성공: 등록 완료]
    D2 --> E2[실패 또는 중복: 페이지 해제]
    E1 --> F[batch_unpin 호출]
    E2 --> F
    C1 --> F
    F --> G[결과 반환]
Loading

리뷰 요약

no 파일 라인 심각도 요약
1 adapter.py 131 MEDIUM parent_allocator=None AttributeError 위험
2 adapter.py 160 MEDIUM ensure_region_pool TOCTOU 경합 조건
3 handler.py 831 MEDIUM batch_exists fallback 페이지 누수 증폭
4 handler.py 848 MEDIUM batch_register 응답 길이 불일치 팬텀 엔트리

기존 리뷰어(youngrok-XCENA)가 지적한 10개 항목과 중복되지 않는 새로운 발견 사항입니다.

전체 평가

CXL zero-copy 아키텍처 전환은 잘 설계되었습니다. 어댑터 패턴, 주소 인코딩, prefix-stop 핀 시맨틱 등 핵심 설계가 견고합니다. 위 4건은 엣지 케이스 방어 보강 사항이며, 아키텍처 자체를 변경할 필요는 없습니다. 인라인 코멘트는 개별 리뷰 코멘트로 아래에 첨부합니다.

Comment thread maru_lmcache/adapter.py
shapes=self._shapes if len(self._shapes) > 1 else None,
dtypes=self._dtypes if len(self._dtypes) > 1 else None,
)
objs.append(TensorMemoryObj(tensor, metadata, parent_allocator=None))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] parent_allocator=None - AttributeError 위험

TensorMemoryObjparent_allocator=None으로 생성하고 있습니다. LMCache의 TensorMemoryObj.free() 또는 내부 로직이 self.parent_allocator.free(self)를 호출하면 AttributeError: NoneType object has no attribute free가 발생합니다.

현재는 CxlMemoryAdapter가 free()를 직접 관리하므로 문제가 없지만, LMCache 내부에서 MemoryObj lifecycle을 자동 관리하는 코드 경로가 추가되면 런타임 에러로 이어질 수 있습니다.

제안: parent_allocator=self를 전달하거나, 최소한 이 설계 결정에 대한 주석을 추가하여 향후 유지보수자가 인지할 수 있도록 하면 좋겠습니다.

Comment thread maru_lmcache/adapter.py
if region_id in self._pool:
return True

self._build_region_pool(region_id, page_count)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] ensure_region_pool TOCTOU 경합 조건

157-158행에서 double-checked locking으로 중복 빌드를 방지하고 있지만, 159행에서 lock을 해제한 뒤 160행에서 _build_region_pool()을 호출합니다. 두 스레드가 동시에 157행의 검사를 통과하면 같은 region에 대해 pool을 두 번 빌드할 수 있습니다.

_build_region_pool 내부에서 self._pool[region_id] = objs로 덮어쓰므로 데이터 손상은 아니지만, 불필요한 중복 작업이 발생합니다.

제안: lock을 유지한 채로 빌드하거나, _build_region_pool 진입 시 재확인 로직을 추가하면 됩니다:

def _build_region_pool(self, region_id, page_count):
    with self._lock:
        if region_id in self._pool:
            return

Comment thread maru_handler/handler.py
# Phase 2: Build register entries, free duplicates
for i, (key, handle) in enumerate(zip(keys, handles, strict=True)):
if key in self._key_to_location:
self._owned.free(handle._region_id, handle._page_index)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] batch_exists 실패 시 fallback이 페이지 누수를 증폭

806행에서 results[True] * len(keys)로 초기화한 뒤, batch_exists RPC가 실패하면 개별 exists_kv fallback으로 전환합니다. 하지만 fallback 중 일부 키에서 타임아웃이 발생하면 해당 인덱스는 초기값 True가 그대로 유지됩니다.

True는 캐시 히트를 의미하므로, 실제로는 존재하지 않는 키에 대해 페이지 할당을 건너뛰게 되어 데이터 불일치가 발생할 수 있습니다. 반대로 False로 초기화하면 불필요한 재저장만 발생하므로 안전합니다.

제안: results = [False] * len(keys)로 초기화하여 fail-safe 방향으로 변경하세요.

Comment thread maru_handler/handler.py

offset = page_index * chunk_size
register_entries.append((key, region_id, offset, total_size))
register_entries.append((key, region_id, offset, handle._size))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] batch_register 응답 길이 불일치 시 팬텀 엔트리

869-870행에서 batch_resp.results 길이가 요청한 엔트리 수보다 짧으면 누락된 인덱스의 결과가 처리되지 않습니다. is_new 검사를 건너뛰므로 중복 등록된 페이지가 해제되지 않아 메모리 누수가 발생합니다.

서버가 항상 동일 길이를 반환한다고 가정하고 있지만, 서버 버그나 프로토콜 불일치 시 silent failure가 됩니다.

제안: 응답 길이 검증을 추가하세요:

if len(batch_resp.results) != len(new_entries):
    logger.error("batch_register response length mismatch")

jooho-XCENA and others added 2 commits April 1, 2026 09:59
Resolve conflict in maru_handler/handler.py: keep feat/maru_backend's
handle-only store() signature, removing the allocate+memcpy path added
in main which is no longer needed with the new zero-copy architecture.
Hard assertion `speedup > 0.5` fails on fast local IPC where
run_coroutine_threadsafe scheduling overhead outweighs pipeline benefit.
Replace with a UserWarning so the test still validates correctness
(all futures succeed) without causing CI failures on low-latency setups.
@kihwan-XCENA

kihwan-XCENA commented Apr 2, 2026

Copy link
Copy Markdown
Collaborator
  • pytest
image
  • vllm + lmcache +maru running in Naru
  • Example running

@jooho-XCENA jooho-XCENA merged commit ece0311 into main Apr 3, 2026
3 checks passed
@jooho-XCENA jooho-XCENA deleted the feat/maru_backend branch April 3, 2026 06:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants