From 29b63af0dc180b7ca032df5ca50ee1ddb35261fa Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 11 May 2026 17:02:50 +0800 Subject: [PATCH 1/3] add retransmission Signed-off-by: 0oshowero0 --- .../storage/clients/mooncake_client.py | 113 ++++++++++++++++-- 1 file changed, 100 insertions(+), 13 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 28706a3..7c88d37 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -14,6 +14,7 @@ # limitations under the License. import pickle +import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -33,8 +34,10 @@ except ImportError: MOONCAKE_STORE_IMPORTED = False -BATCH_SIZE_LIMIT: int = 200 +BATCH_SIZE_LIMIT: int = 400 MAX_WORKER_THREADS = 4 +MAX_RETRIES = 3 +RETRY_DELAY_SECONDS = 1.0 @StorageClientFactory.register("MooncakeStoreClient") @@ -238,25 +241,109 @@ def _get_tensors_thread_worker( ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) if len(ret_codes) != len(batch_keys): raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") - for i, ret in enumerate(ret_codes): - if ret < 0: - raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}") + + failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0] + if not failed_indices: + return batch_buffer_tensors, indexes + + # error handling + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [ret_codes[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error( + f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) + + for attempt in range(1, MAX_RETRIES + 1): + # Reuse the originally allocated pointers; no need to allocate/register new buffers. + retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices] + retry_nbytes = [batch_nbytes[i] for i in current_failed_indices] + + retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes) + + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] + + for i, ret in enumerate(retry_codes): + if ret < 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) + + if not next_failed_indices: + break # All retries in this attempt succeeded. + + # Narrow down to still-failed items for the next retry attempt. + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: + # All retries exhausted. + raise RuntimeError( + f"batch_get_into failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." + ) + finally: self._unregister_all_buffers(region_ptrs) return batch_buffer_tensors, indexes def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]: - results = [] - - batch_results = self._store.get_batch(batch_keys) - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - - batch_results = [pickle.loads(result) if result != b"" else None for result in batch_results] - results.extend(batch_results) + raw_results = self._store.get_batch(batch_keys) + if len(raw_results) != len(batch_keys): + raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}") + + # TODO: Use MooncakeStore provided ret codes to detect transmission failures when supported + # Currently we rely on empty bytes (b'') to detect transmission failures because + # MooncakeStore does not currently return a separate status code per key. + failed_indices = [i for i, result in enumerate(raw_results) if result == b""] + if failed_indices: + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error(f"get_batch failed for keys {current_failed_keys}. Retrying up to {MAX_RETRIES} times...") + + for attempt in range(1, MAX_RETRIES + 1): + retry_results = self._store.get_batch(current_failed_keys) + + next_failed_keys = [] + next_failed_indices = [] + + for i, result in enumerate(retry_results): + original_idx = current_failed_indices[i] + if result == b"": + next_failed_keys.append(current_failed_keys[i]) + next_failed_indices.append(original_idx) + else: + # Write the successfully retried value back to its original slot immediately. + raw_results[original_idx] = result + + if not next_failed_indices: + break # All retries in this attempt succeeded. + + # Narrow down to still-failed items for the next retry attempt. + current_failed_keys = next_failed_keys + current_failed_indices = next_failed_indices + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: + # All retries exhausted. + # FIXME: raise error here when we can distinguish transmission failures from empty values + logger.error( + f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times. " + f"Please validate if the values corresponding to these keys are `None` during put." + ) - return results, indexes + deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results] + return deserialized_results, indexes def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: """Deletes multiple keys from MooncakeStore. From 191e82322020ef22566440aeab8504c219259b46 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 12 May 2026 10:05:25 +0800 Subject: [PATCH 2/3] add retry for put Signed-off-by: 0oshowero0 --- .../storage/clients/mooncake_client.py | 104 +++++++++++++++--- 1 file changed, 91 insertions(+), 13 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 7c88d37..d5e0a4c 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -150,23 +150,94 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ try: results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) - if not all(r == 0 for r in results): - failed_indices = [j for j, r in enumerate(results) if r != 0] - error_codes = [results[j] for j in failed_indices] + if len(results) != len(batch_keys): + raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}") + + failed_indices = [j for j, r in enumerate(results) if r != 0] + if not failed_indices: + return + + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [results[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error( + f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) + + for attempt in range(1, MAX_RETRIES + 1): + retry_ptrs = [batch_ptrs[i] for i in current_failed_indices] + retry_sizes = [batch_sizes[i] for i in current_failed_indices] + + retry_results = self._store.batch_upsert_from( + current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config + ) + + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] + + for i, ret in enumerate(retry_results): + if ret != 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) + + if not next_failed_indices: + logger.info("batch_upsert_from succeeded after retransmission.") + break # All retries in this attempt succeeded. + + logger.error( + f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) + + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: raise RuntimeError( - f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}" + f"batch_upsert_from failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." ) + finally: self._unregister_all_buffers(batch_ptr_reduced) def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]): """Worker thread for putting batch of non-tensors to MooncakeStore.""" - batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] + serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] - ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) - if ret != 0: - raise RuntimeError(f"upsert_batch failed with error code: {ret}") + # FIXME: Use element-level ret value to precise retransmit when MooncakeStore supports + ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) + if ret == 0: + return + + logger.error( + f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. " + f"Retrying up to {MAX_RETRIES} times..." + ) + + for attempt in range(1, MAX_RETRIES + 1): + ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) + if ret == 0: + logger.info("upsert_batch succeeded after retransmission.") + return + + logger.error( + f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}." + ) + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + + raise RuntimeError( + f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times." + ) def get( self, @@ -274,8 +345,14 @@ def _get_tensors_thread_worker( next_failed_codes.append(ret) if not next_failed_indices: + logger.info("batch_get_into succeeded after retransmission.") break # All retries in this attempt succeeded. + logger.error( + f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) + # Narrow down to still-failed items for the next retry attempt. current_failed_indices = next_failed_indices current_failed_keys = next_failed_keys @@ -300,7 +377,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> if len(raw_results) != len(batch_keys): raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}") - # TODO: Use MooncakeStore provided ret codes to detect transmission failures when supported + # FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported # Currently we rely on empty bytes (b'') to detect transmission failures because # MooncakeStore does not currently return a separate status code per key. failed_indices = [i for i, result in enumerate(raw_results) if result == b""] @@ -326,8 +403,11 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> raw_results[original_idx] = result if not next_failed_indices: + logger.info("get_batch succeeded after retransmission.") break # All retries in this attempt succeeded. + logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.") + # Narrow down to still-failed items for the next retry attempt. current_failed_keys = next_failed_keys current_failed_indices = next_failed_indices @@ -336,10 +416,8 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> time.sleep(RETRY_DELAY_SECONDS) else: # All retries exhausted. - # FIXME: raise error here when we can distinguish transmission failures from empty values - logger.error( - f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times. " - f"Please validate if the values corresponding to these keys are `None` during put." + raise RuntimeError( + f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times." ) deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results] From eb7f2d153a6117603bd044f4e56f335d12d8187b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 12 May 2026 10:14:00 +0800 Subject: [PATCH 3/3] fix comments Signed-off-by: 0oshowero0 --- transfer_queue/storage/clients/mooncake_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index d5e0a4c..ede1de4 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -213,7 +213,9 @@ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] - # FIXME: Use element-level ret value to precise retransmit when MooncakeStore supports + # FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch, + # switch the bytes write/read paths from whole-batch retry to per-key selective retry, + # matching the tensor-path behaviour. ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) if ret == 0: return