diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 28706a3..ede1de4 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -14,6 +14,7 @@ # limitations under the License. import pickle +import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -33,8 +34,10 @@ except ImportError: MOONCAKE_STORE_IMPORTED = False -BATCH_SIZE_LIMIT: int = 200 +BATCH_SIZE_LIMIT: int = 400 MAX_WORKER_THREADS = 4 +MAX_RETRIES = 3 +RETRY_DELAY_SECONDS = 1.0 @StorageClientFactory.register("MooncakeStoreClient") @@ -147,23 +150,96 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ try: results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) - if not all(r == 0 for r in results): - failed_indices = [j for j, r in enumerate(results) if r != 0] - error_codes = [results[j] for j in failed_indices] + if len(results) != len(batch_keys): + raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}") + + failed_indices = [j for j, r in enumerate(results) if r != 0] + if not failed_indices: + return + + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [results[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error( + f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." 
+ ) + + for attempt in range(1, MAX_RETRIES + 1): + retry_ptrs = [batch_ptrs[i] for i in current_failed_indices] + retry_sizes = [batch_sizes[i] for i in current_failed_indices] + + retry_results = self._store.batch_upsert_from( + current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config + ) + + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] + + for i, ret in enumerate(retry_results): + if ret != 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) + + if not next_failed_indices: + logger.info("batch_upsert_from succeeded after retransmission.") + break # All retries in this attempt succeeded. + + logger.error( + f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) + + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: raise RuntimeError( - f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}" + f"batch_upsert_from failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." 
) + finally: self._unregister_all_buffers(batch_ptr_reduced) def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]): """Worker thread for putting batch of non-tensors to MooncakeStore.""" - batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] + serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] - ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) - if ret != 0: - raise RuntimeError(f"upsert_batch failed with error code: {ret}") + # FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch, + # switch the bytes write/read paths from whole-batch retry to per-key selective retry, + # matching the tensor-path behaviour. + ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) + if ret == 0: + return + + logger.error( + f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. " + f"Retrying up to {MAX_RETRIES} times..." + ) + + for attempt in range(1, MAX_RETRIES + 1): + ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) + if ret == 0: + logger.info("upsert_batch succeeded after retransmission.") + return + + logger.error( + f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}." + ) + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + + raise RuntimeError( + f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times." 
+ ) def get( self, @@ -238,25 +314,116 @@ def _get_tensors_thread_worker( ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) if len(ret_codes) != len(batch_keys): raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") - for i, ret in enumerate(ret_codes): - if ret < 0: - raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}") + + failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0] + if not failed_indices: + return batch_buffer_tensors, indexes + + # error handling + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [ret_codes[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error( + f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) + + for attempt in range(1, MAX_RETRIES + 1): + # Reuse the originally allocated pointers; no need to allocate/register new buffers. + retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices] + retry_nbytes = [batch_nbytes[i] for i in current_failed_indices] + + retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes) + + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] + + for i, ret in enumerate(retry_codes): + if ret < 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) + + if not next_failed_indices: + logger.info("batch_get_into succeeded after retransmission.") + break # All retries in this attempt succeeded. + + logger.error( + f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) + + # Narrow down to still-failed items for the next retry attempt. 
+ current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: + # All retries exhausted. + raise RuntimeError( + f"batch_get_into failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." + ) + finally: self._unregister_all_buffers(region_ptrs) return batch_buffer_tensors, indexes def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]: - results = [] - - batch_results = self._store.get_batch(batch_keys) - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - - batch_results = [pickle.loads(result) if result != b"" else None for result in batch_results] - results.extend(batch_results) + raw_results = self._store.get_batch(batch_keys) + if len(raw_results) != len(batch_keys): + raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}") + + # FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported + # Currently we rely on empty bytes (b'') to detect transmission failures because + # MooncakeStore does not currently return a separate status code per key. + failed_indices = [i for i, result in enumerate(raw_results) if result == b""] + if failed_indices: + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_indices = failed_indices + + logger.error(f"get_batch failed for keys {current_failed_keys}. Retrying up to {MAX_RETRIES} times...") + + for attempt in range(1, MAX_RETRIES + 1): + retry_results = self._store.get_batch(current_failed_keys) + + next_failed_keys = [] + next_failed_indices = [] + + for i, result in enumerate(retry_results): + original_idx = current_failed_indices[i] + if result == b"": + next_failed_keys.append(current_failed_keys[i]) + next_failed_indices.append(original_idx) + else: + # Write the successfully retried value back to its original slot immediately. + raw_results[original_idx] = result + + if not next_failed_indices: + logger.info("get_batch succeeded after retransmission.") + break # All retries in this attempt succeeded. + + logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.") + + # Narrow down to still-failed items for the next retry attempt. + current_failed_keys = next_failed_keys + current_failed_indices = next_failed_indices + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + else: + # All retries exhausted. + raise RuntimeError( + f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times." + ) - return results, indexes + deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results] + return deserialized_results, indexes def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: """Deletes multiple keys from MooncakeStore.