Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 188 additions & 21 deletions transfer_queue/storage/clients/mooncake_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import pickle
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

Expand All @@ -33,8 +34,10 @@
except ImportError:
MOONCAKE_STORE_IMPORTED = False

# Batch size cap; presumably the maximum number of keys handed to one store call
# (its usage is outside this view -- confirm against the put/get dispatch code).
BATCH_SIZE_LIMIT: int = 400
# Presumably the thread-pool size used to run the *_thread_worker methods
# (its usage is outside this view -- confirm).
MAX_WORKER_THREADS = 4
# Number of retry attempts made after a failed store call before raising RuntimeError.
MAX_RETRIES = 3
# Seconds slept between two consecutive retry attempts.
RETRY_DELAY_SECONDS = 1.0
Comment on lines +37 to +40
Comment on lines +37 to +40


@StorageClientFactory.register("MooncakeStoreClient")
Expand Down Expand Up @@ -147,23 +150,96 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[

try:
results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
if not all(r == 0 for r in results):
failed_indices = [j for j, r in enumerate(results) if r != 0]
error_codes = [results[j] for j in failed_indices]
if len(results) != len(batch_keys):
raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}")

failed_indices = [j for j, r in enumerate(results) if r != 0]
if not failed_indices:
return

current_failed_keys = [batch_keys[i] for i in failed_indices]
current_failed_codes = [results[i] for i in failed_indices]
current_failed_indices = failed_indices

logger.error(
f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. "
f"Retrying up to {MAX_RETRIES} times..."
)

for attempt in range(1, MAX_RETRIES + 1):
retry_ptrs = [batch_ptrs[i] for i in current_failed_indices]
retry_sizes = [batch_sizes[i] for i in current_failed_indices]

retry_results = self._store.batch_upsert_from(
current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config
)

next_failed_indices = []
next_failed_keys = []
next_failed_codes = []

for i, ret in enumerate(retry_results):
if ret != 0:
next_failed_indices.append(current_failed_indices[i])
next_failed_keys.append(current_failed_keys[i])
next_failed_codes.append(ret)

if not next_failed_indices:
logger.info("batch_upsert_from succeeded after retransmission.")
break # All retries in this attempt succeeded.

logger.error(
f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys "
f"with error codes {next_failed_codes}."
)

current_failed_indices = next_failed_indices
current_failed_keys = next_failed_keys
current_failed_codes = next_failed_codes

if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY_SECONDS)
else:
raise RuntimeError(
f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}"
f"batch_upsert_from failed for keys {current_failed_keys} with error codes "
f"{current_failed_codes} after retrying {MAX_RETRIES} times."
)

finally:
self._unregister_all_buffers(batch_ptr_reduced)

def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
    """Worker thread for putting batch of non-tensors to MooncakeStore.

    Every value is pickled exactly once and the whole batch is written with a
    single ``upsert_batch`` call. ``upsert_batch`` returns one status code for
    the entire batch, so a failure is retried for the whole batch: up to
    ``MAX_RETRIES`` additional attempts, sleeping ``RETRY_DELAY_SECONDS``
    between consecutive retries.

    Args:
        batch_keys: Keys to write.
        batch_values: Python objects to serialize, aligned with ``batch_keys``.

    Raises:
        RuntimeError: If ``upsert_batch`` still returns a non-zero code after
            all retries.
    """
    # Serialize once up front so every retry resends the identical payload.
    serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]

    # FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch,
    # switch the bytes write/read paths from whole-batch retry to per-key selective retry,
    # matching the tensor-path behaviour.
    ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
    if ret == 0:
        return

    logger.error(
        f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. "
        f"Retrying up to {MAX_RETRIES} times..."
    )

    for attempt in range(1, MAX_RETRIES + 1):
        ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
        if ret == 0:
            logger.info("upsert_batch succeeded after retransmission.")
            return

        logger.error(
            f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}."
        )
        # No point in sleeping after the final attempt; fall through to the raise instead.
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_DELAY_SECONDS)

    raise RuntimeError(
        f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times."
    )

def get(
self,
Expand Down Expand Up @@ -238,25 +314,116 @@ def _get_tensors_thread_worker(
ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes)
if len(ret_codes) != len(batch_keys):
raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}")
for i, ret in enumerate(ret_codes):
if ret < 0:
raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}")

failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0]
if not failed_indices:
return batch_buffer_tensors, indexes

# error handling
current_failed_keys = [batch_keys[i] for i in failed_indices]
current_failed_codes = [ret_codes[i] for i in failed_indices]
current_failed_indices = failed_indices

logger.error(
f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. "
f"Retrying up to {MAX_RETRIES} times..."
)

for attempt in range(1, MAX_RETRIES + 1):
# Reuse the originally allocated pointers; no need to allocate/register new buffers.
retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices]
retry_nbytes = [batch_nbytes[i] for i in current_failed_indices]

retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes)

next_failed_indices = []
next_failed_keys = []
next_failed_codes = []

for i, ret in enumerate(retry_codes):
if ret < 0:
next_failed_indices.append(current_failed_indices[i])
next_failed_keys.append(current_failed_keys[i])
next_failed_codes.append(ret)

Comment on lines +332 to +348
if not next_failed_indices:
logger.info("batch_get_into succeeded after retransmission.")
break # All retries in this attempt succeeded.

logger.error(
f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys "
f"with error codes {next_failed_codes}."
)

# Narrow down to still-failed items for the next retry attempt.
current_failed_indices = next_failed_indices
current_failed_keys = next_failed_keys
current_failed_codes = next_failed_codes

if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY_SECONDS)
else:
# All retries exhausted.
raise RuntimeError(
f"batch_get_into failed for keys {current_failed_keys} with error codes "
f"{current_failed_codes} after retrying {MAX_RETRIES} times."
)

finally:
self._unregister_all_buffers(region_ptrs)

return batch_buffer_tensors, indexes

def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]:
    """Worker thread for getting a batch of non-tensors from MooncakeStore.

    Fetches the pickled payloads with ``get_batch``. Keys whose payload comes
    back as empty bytes are treated as failed and re-fetched selectively: up to
    ``MAX_RETRIES`` additional attempts, sleeping ``RETRY_DELAY_SECONDS``
    between consecutive retries. Recovered payloads are written back into their
    original slot so the output stays aligned with ``batch_keys``.

    Args:
        batch_keys: Keys to read.
        indexes: Passed through unchanged as the second element of the result.

    Returns:
        A tuple ``(values, indexes)`` where ``values[i]`` is the unpickled
        object stored under ``batch_keys[i]``.

    Raises:
        RuntimeError: If ``get_batch`` returns a different number of items than
            keys requested (on the first call or on any retry), or if some keys
            still yield empty bytes after all retries.
    """
    # Copy into a list so recovered payloads can be assigned by index below.
    raw_results = list(self._store.get_batch(batch_keys))
    if len(raw_results) != len(batch_keys):
        raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}")

    # FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported
    # Currently we rely on empty bytes (b'') to detect transmission failures because
    # MooncakeStore does not currently return a separate status code per key.
    pending_indices = [i for i, result in enumerate(raw_results) if result == b""]
    if pending_indices:
        pending_keys = [batch_keys[i] for i in pending_indices]
        logger.error(f"get_batch failed for keys {pending_keys}. Retrying up to {MAX_RETRIES} times...")

        for attempt in range(1, MAX_RETRIES + 1):
            retry_results = self._store.get_batch(pending_keys)
            # A short reply would otherwise let the missing keys pass as "recovered".
            if len(retry_results) != len(pending_keys):
                raise RuntimeError(
                    f"get_batch returned {len(retry_results)} items, expected {len(pending_keys)}"
                )

            still_failed_indices = []
            for original_idx, result in zip(pending_indices, retry_results):
                if result == b"":
                    still_failed_indices.append(original_idx)
                else:
                    # Write the successfully retried value back to its original slot immediately.
                    raw_results[original_idx] = result

            if not still_failed_indices:
                logger.info("get_batch succeeded after retransmission.")
                break  # All retries in this attempt succeeded.

            logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(still_failed_indices)} keys.")

            # Narrow down to still-failed items for the next retry attempt.
            pending_indices = still_failed_indices
            pending_keys = [batch_keys[i] for i in pending_indices]

            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY_SECONDS)
        else:
            # Loop ended without a break: all retries exhausted.
            raise RuntimeError(f"get_batch failed for keys {pending_keys} after retrying {MAX_RETRIES} times.")

    # Every slot is non-empty here (otherwise we raised above), so each can be unpickled.
    # NOTE(review): pickle.loads executes arbitrary code from the payload; this is only safe
    # while the store contents are written exclusively by trusted peers.
    return [pickle.loads(result) for result in raw_results], indexes
Comment on lines +419 to +426

def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None:
"""Deletes multiple keys from MooncakeStore.
Expand Down
Loading