From a78525570ac6e2188f037ea93b4fab9de3ca2dd4 Mon Sep 17 00:00:00 2001
From: kwanUm
Date: Sun, 27 Oct 2024 14:14:50 +0000
Subject: [PATCH 1/3] nits

---
 ldp/alg/rollout.py       |   8 +-
 ldp/graph/async_torch.py | 199 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 205 insertions(+), 2 deletions(-)

diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py
index eec32ae5..27b748b9 100644
--- a/ldp/alg/rollout.py
+++ b/ldp/alg/rollout.py
@@ -194,12 +194,16 @@ async def _sample_trajectories_from_envs(
         self.traj_buffer.clear()
 
         traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
-        await asyncio.gather(
+        from tqdm.asyncio import tqdm_asyncio
+
+        await tqdm_asyncio.gather(
             *(
                 self._rollout(*args, max_steps=max_steps)
                 for args in zip(traj_ids, environments, strict=True)
-            )
+            ),
+            desc="Sampling trajectories",
         )
+
         return [self.traj_buffer[traj_id] for traj_id in traj_ids]
 
     async def _rollout(
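Note on the hunk above: tqdm_asyncio.gather is a drop-in replacement for asyncio.gather that renders a progress bar as awaitables complete, while still returning results in input order. A minimal, self-contained sketch of the same pattern; fake_rollout and the parameters below are illustrative stand-ins, not taken from this patch:

import asyncio

from tqdm.asyncio import tqdm_asyncio


async def fake_rollout(i: int) -> int:
    await asyncio.sleep(0.05 * i)  # stand-in for stepping an environment
    return i


async def main() -> None:
    # Behaves like asyncio.gather, plus a progress bar that ticks as each
    # trajectory finishes; results come back in input order.
    results = await tqdm_asyncio.gather(
        *(fake_rollout(i) for i in range(10)), desc="Sampling trajectories"
    )
    print(results)


asyncio.run(main())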
diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py
index c8e39f6f..2bb9d4f6 100644
--- a/ldp/graph/async_torch.py
+++ b/ldp/graph/async_torch.py
@@ -146,6 +146,205 @@ async def _maybe_process_batch(self):
 
     @abstractmethod
     async def _batched_call(self, batch_kwargs: dict[str, Any]):
         """Logic to call the worker on a batch of inputs."""
+
+
+
+class AsyncBufferedWorker(ABC):
+    """Abstract class for a worker that buffers inputs and processes them in batches."""
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_wait_interval: float,
+        collate_fn: Callable = lambda x: x,
+        decollate_fn: Callable = lambda x: x,
+    ):
+        """Initialize.
+
+        Args:
+            batch_size: The target batch size to use when calling the module. As soon as
+                batch_size calls are made, a forward pass is executed.
+            max_wait_interval: The maximum time to wait for a batch to fill up before
+                executing the calls we have buffered.
+            collate_fn: A function to pre-process a list of inputs into a batch. Defaults to a
+                no-op.
+            decollate_fn: The inverse of collate_fn: takes the batched output and returns an
+                ordered list of per-sample outputs. Defaults to a no-op.
+        """
+        self.batch_size = batch_size
+        self.timeout = max_wait_interval
+        self.collate_fn = collate_fn
+        self.decollate_fn = decollate_fn
+
+        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
+        self._result_buffer: dict[UUID, Any] = {}
+        self._lock = asyncio.Lock()
+        self._batch_ready_event = asyncio.Event()
+        self._processed_events = {}
+        self._counter = 0
+        self._events_count = {}
+
+    async def __call__(self, **kwargs):
+        request_id = uuid4()
+        request_ts = time.time()
+
+        async with self._lock:
+            self._processed_events[request_id] = asyncio.Event()
+            self._events_count[request_id] = self._counter
+            self._counter += 1
+            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._work_buffer.append((request_ts, request_id, kwargs))
+
+            # If we've reached batch size, we trigger the processing event immediately
+            if len(self._work_buffer) >= self.batch_size:
+                self._batch_ready_event.set()
+
+        try:
+            # Wait for either the batch to fill up or the timeout to expire
+            await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            pass
+
+        await self._maybe_process_batch()
+
+        await self._processed_events[request_id].wait()
+
+        async with self._lock:
+            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._events_count.pop(request_id)
+            self._processed_events.pop(request_id)
+            return self._result_buffer.pop(request_id)
+
+    async def _maybe_process_batch(self):
+        """If the buffer has reached batch size or we have waited long enough, process the oldest requests.
+
+        If neither condition is met, do nothing.
+        """
+        async with self._lock:
+            # Nothing to process if the buffer is empty
+            if len(self._work_buffer) == 0:
+                return
+
+            self._work_buffer.sort(key=operator.itemgetter(0))
+
+            batch = self._work_buffer[: self.batch_size]
+            self._work_buffer = self._work_buffer[self.batch_size :]
+
+            if len(self._work_buffer) < self.batch_size:
+                self._batch_ready_event.clear()
+
+            # Construct the batch tensors
+            sample_kwargs = [x[2] for x in batch]
+            batch_kwargs = self.collate_fn(sample_kwargs)
+
+        print(f"starting to wait for batched call, counter: {self._counter}")
+        batched_results = await self._batched_call(batch_kwargs)
+        print(f"finished waiting for batched call, counter: {self._counter}")
+        request_ids = [x[1] for x in batch]
+        results = self.decollate_fn(batched_results)
+        async with self._lock:
+            print(f"updating result buffer, counter: {self._counter}")
+            self._result_buffer.update(zip(request_ids, results, strict=True))
+            for request_id in request_ids:
+                self._processed_events[request_id].set()
+
+    def _process_batch(self):
+        """Processes the current batch."""
+
+    @abstractmethod
+    async def _batched_call(self, batch_kwargs: dict[str, Any]):
+        """Logic to call the worker on a batch of inputs."""
+
+
+
+
+class AsyncBufferedWorker(ABC):
+    def __init__(self, batch_size, max_wait_interval, collate_fn=lambda x: x, decollate_fn=lambda x: x):
+        self.batch_size = batch_size
+        self.timeout = max_wait_interval
+        self.collate_fn = collate_fn
+        self.decollate_fn = decollate_fn
+
+        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
+        self._result_buffer: dict[UUID, Any] = {}
+        self._lock = asyncio.Lock()
+        self._new_data_event = asyncio.Event()
+
+        self._processed_events = {}
+        self._counter = 0
+        self._events_count = {}
+
+        # Start the background batch processing task
+        self._batch_processing_task = asyncio.create_task(self._batch_processor())
+
+    async def __call__(self, **kwargs):
+        request_id = uuid4()
+        request_ts = time.time()
+
+        async with self._lock:
+            self._processed_events[request_id] = asyncio.Event()
+            self._events_count[request_id] = self._counter
+            self._counter += 1
+            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._work_buffer.append((request_ts, request_id, kwargs))
+            if len(self._work_buffer) >= self.batch_size:
+                self._new_data_event.set()  # Signal that new data has arrived
+                print(f"did set new data event, counter: {self._counter}")
+
+        # Wait for the result to be processed
+        await self._processed_events[request_id].wait()
+
+        async with self._lock:
+            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._events_count.pop(request_id)
+            self._processed_events.pop(request_id)
+            return self._result_buffer.pop(request_id)
+
+    async def _batch_processor(self):
+        while True:
+            try:
+                # Wait for new data or timeout
+                await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
+            except asyncio.TimeoutError:
+                pass
+
+            async with self._lock:
+                if len(self._work_buffer) == 0:
+                    self._new_data_event.clear()
+                    continue
+
+                now = time.time()
+                # Sort the work buffer by timestamp to maintain order
+                self._work_buffer.sort(key=operator.itemgetter(0))
+
+                batch = self._work_buffer[:self.batch_size]
+                self._work_buffer = self._work_buffer[self.batch_size:]
+                if len(self._work_buffer) == 0:
+                    self._new_data_event.clear()
+
+                # Process the batch outside the lock
+                sample_kwargs = [x[2] for x in batch]
+                batch_kwargs = self.collate_fn(sample_kwargs)
+            print(f"Starting batched call, counter: {self._counter}, batch size: {len(batch)}")
+            batched_results = await self._batched_call(batch_kwargs)
+            print(f"Finished batched call, counter: {self._counter}")
+            request_ids = [x[1] for x in batch]
+            results = self.decollate_fn(batched_results)
+            async with self._lock:
+                print(f"Updating result buffer, counter: {self._counter}")
+                self._result_buffer.update(zip(request_ids, results))
+                for request_id in request_ids:
+                    self._processed_events[request_id].set()
+
+            # Let other requests proceed as soon as their result is available
+            await asyncio.sleep(0.0)
+
+    @abstractmethod
+    async def _batched_call(self, batch_kwargs: dict[str, Any]):
+        """Logic to call the worker on a batch of inputs."""
+        pass
+
 
 class AsyncTorchModule(AsyncBufferedWorker):
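Note on the classes added in this patch: an AsyncBufferedWorker buffers per-request kwargs and flushes them through the abstract _batched_call hook, so many concurrent callers share one batched forward pass. A toy usage sketch under the interface shown above; DoublingWorker and its parameters are hypothetical, with the default no-op collate_fn _batched_call receives the ordered list of per-request kwargs dicts, and importing ldp.graph.async_torch requires `pip install ldp[nn]`:

import asyncio
from typing import Any

from ldp.graph.async_torch import AsyncBufferedWorker


class DoublingWorker(AsyncBufferedWorker):
    """Hypothetical subclass: the 'batched forward pass' just doubles each value."""

    async def _batched_call(self, batch_kwargs: Any):
        # With the default no-op collate_fn, batch_kwargs is the ordered list
        # of per-request kwargs, e.g. [{"value": 0}, {"value": 1}, ...].
        await asyncio.sleep(0.01)  # stand-in for a real batched model call
        return [2 * kw["value"] for kw in batch_kwargs]


async def main() -> None:
    worker = DoublingWorker(batch_size=4, max_wait_interval=0.05)
    # Ten concurrent callers are served in batches of at most 4; each caller
    # still receives only its own result, and gather preserves input order.
    print(await asyncio.gather(*(worker(value=i) for i in range(10))))


asyncio.run(main())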
From ca25c08e28b47b41a7bd5560fabbb5f5044acc7e Mon Sep 17 00:00:00 2001
From: kwanUm
Date: Sun, 3 Nov 2024 10:51:25 +0000
Subject: [PATCH 2/3] nits

---
 ldp/graph/async_torch.py | 148 +++++++++++++++++++++++++--------------
 1 file changed, 97 insertions(+), 51 deletions(-)

diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py
index ea420545..871feb0b 100644
--- a/ldp/graph/async_torch.py
+++ b/ldp/graph/async_torch.py
@@ -3,6 +3,7 @@
 import asyncio
+import logging
 import operator
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import nullcontext
@@ -18,6 +19,9 @@
         "ldp.graph.async_torch requires PyTorch as a dependency. "
         "Please run `pip install ldp[nn]`."
     ) from None
+
+
+logger = logging.getLogger(__name__)
 
 _TORCH_LOCK = asyncio.Lock()
 
@@ -149,7 +153,7 @@ async def _batched_call(self, batch_kwargs: dict[str, Any]):
 
 
 
-class AsyncBufferedWorker(ABC):
+class AsyncBufferedWorker2(ABC):
     """Abstract class for a worker that buffers inputs and processes them in batches."""
 
     def __init__(
@@ -255,12 +259,16 @@ def _process_batch(self):
     @abstractmethod
     async def _batched_call(self, batch_kwargs: dict[str, Any]):
         """Logic to call the worker on a batch of inputs."""
-
-
-
-
-class AsyncBufferedWorker(ABC):
-    def __init__(self, batch_size, max_wait_interval, collate_fn=lambda x: x, decollate_fn=lambda x: x):
+
+
+class AsyncBufferedWorker3(ABC):
+    def __init__(
+        self,
+        batch_size: int,
+        max_wait_interval: float,
+        collate_fn: Callable = lambda x: x,
+        decollate_fn: Callable = lambda x: x,
+    ):
         self.batch_size = batch_size
         self.timeout = max_wait_interval
         self.collate_fn = collate_fn
         self.decollate_fn = decollate_fn
@@ -271,82 +279,120 @@ def __init__(self, batch_size, max_wait_interval, collate_fn=lambda x: x, decoll
         self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
         self._result_buffer: dict[UUID, Any] = {}
         self._lock = asyncio.Lock()
         self._new_data_event = asyncio.Event()
 
-        self._processed_events = {}
+        self._processed_events: dict[UUID, asyncio.Event] = {}
         self._counter = 0
-        self._events_count = {}
+        self._events_count: dict[UUID, int] = {}  # Tracks request arrival order, for debugging
+        self._exception: Exception | None = None  # Store exception from _batch_processor
 
         # Start the background batch processing task
         self._batch_processing_task = asyncio.create_task(self._batch_processor())
+        self._batch_processing_task.add_done_callback(self._handle_task_exception)
 
     async def __call__(self, **kwargs):
         request_id = uuid4()
         request_ts = time.time()
 
         async with self._lock:
+            if self._exception is not None:
+                # If an exception has occurred, raise it immediately
+                raise self._exception
             self._processed_events[request_id] = asyncio.Event()
             self._events_count[request_id] = self._counter
             self._counter += 1
-            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
             self._work_buffer.append((request_ts, request_id, kwargs))
             if len(self._work_buffer) >= self.batch_size:
                 self._new_data_event.set()  # Signal that new data has arrived
-                print(f"did set new data event, counter: {self._counter}")
 
-        # Wait for the result to be processed
+        # Wait for the result to be processed or an exception to occur
        await self._processed_events[request_id].wait()
 
         async with self._lock:
-            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
             self._events_count.pop(request_id)
             self._processed_events.pop(request_id)
-            return self._result_buffer.pop(request_id)
+            if self._exception is not None:
+                # If an exception occurred during processing, raise it here
+                raise self._exception
+            elif request_id in self._result_buffer:
+                return self._result_buffer.pop(request_id)
+            else:
+                # Should not happen, but handle just in case
+                raise RuntimeError("Result not available and no exception set.")
 
     async def _batch_processor(self):
-        while True:
-            try:
-                # Wait for new data or timeout
-                await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
-            except asyncio.TimeoutError:
-                pass
-
-            async with self._lock:
-                if len(self._work_buffer) == 0:
-                    self._new_data_event.clear()
-                    continue
-
-                now = time.time()
-                # Sort the work buffer by timestamp to maintain order
-                self._work_buffer.sort(key=operator.itemgetter(0))
-
-                batch = self._work_buffer[:self.batch_size]
-                self._work_buffer = self._work_buffer[self.batch_size:]
-                if len(self._work_buffer) == 0:
-                    self._new_data_event.clear()
-
-                # Process the batch outside the lock
-                sample_kwargs = [x[2] for x in batch]
-                batch_kwargs = self.collate_fn(sample_kwargs)
-            print(f"Starting batched call, counter: {self._counter}, batch size: {len(batch)}")
-            batched_results = await self._batched_call(batch_kwargs)
-            print(f"Finished batched call, counter: {self._counter}")
-            request_ids = [x[1] for x in batch]
-            results = self.decollate_fn(batched_results)
-            async with self._lock:
-                print(f"Updating result buffer, counter: {self._counter}")
-                self._result_buffer.update(zip(request_ids, results))
-                for request_id in request_ids:
-                    self._processed_events[request_id].set()
-
-            # Let other requests proceed as soon as their result is available
-            await asyncio.sleep(0.0)
+        try:
+            while True:
+                try:
+                    # Wait for new data or timeout
+                    await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
+                except asyncio.TimeoutError:
+                    pass
+
+                async with self._lock:
+                    if len(self._work_buffer) == 0:
+                        self._new_data_event.clear()
+                        continue
+
+                    # Sort the work buffer by timestamp to maintain order
+                    self._work_buffer.sort(key=operator.itemgetter(0))
+
+                    batch = self._work_buffer[:self.batch_size]
+                    self._work_buffer = self._work_buffer[self.batch_size:]
+                    if len(self._work_buffer) == 0:
+                        self._new_data_event.clear()
+
+                    # Process the batch outside the lock
+                    sample_kwargs = [x[2] for x in batch]
+                    batch_kwargs = self.collate_fn(sample_kwargs)
+                batched_results = await self._batched_call(batch_kwargs)
+                request_ids = [x[1] for x in batch]
+                results = self.decollate_fn(batched_results)
+                async with self._lock:
+                    self._result_buffer.update(zip(request_ids, results, strict=True))
+                    for request_id in request_ids:
+                        self._processed_events[request_id].set()
+
+                # Let other requests proceed as soon as their result is available
+                await asyncio.sleep(0)
+        except asyncio.CancelledError:
+            pass  # Allow task to exit gracefully
+        except Exception as e:
+            logger.error(f"Exception in _batch_processor: {e}", exc_info=True)
+            # Store the exception
+            async with self._lock:
+                self._exception = e
+                # Notify all pending requests about the exception
+                for event in self._processed_events.values():
+                    event.set()
+
+    def _handle_task_exception(self, task):
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            # Task was cancelled, nothing to do
+            pass
+        except Exception:
+            # Already handled in _batch_processor
+            pass
+
+    async def close(self):
+        self._batch_processing_task.cancel()
+        try:
+            await self._batch_processing_task
+        except asyncio.CancelledError:
+            pass
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.close()
 
     @abstractmethod
     async def _batched_call(self, batch_kwargs: dict[str, Any]):
         """Logic to call the worker on a batch of inputs."""
         pass
-
 
 class AsyncTorchModule(AsyncBufferedWorker):
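Note on the error handling added in PATCH 2/3: the background task stores its exception and sets every pending event, so blocked callers wake up and re-raise instead of hanging forever. A standalone sketch of just that propagation pattern; all names below are illustrative, not from the patch:

import asyncio


class Background:
    def __init__(self) -> None:
        self.exception: Exception | None = None
        self.done = asyncio.Event()
        self.task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        try:
            await asyncio.sleep(0.01)
            raise RuntimeError("boom")  # simulate a failing batched call
        except Exception as e:
            self.exception = e  # store for waiters
            self.done.set()  # wake every pending caller

    async def wait(self) -> None:
        await self.done.wait()
        if self.exception is not None:
            raise self.exception  # surface the background failure to the caller


async def main() -> None:
    bg = Background()
    try:
        await bg.wait()
    except RuntimeError as e:
        print(f"caller saw: {e}")


asyncio.run(main())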
From 8e66c86879b692f32710236d35b8bf50e62f64b7 Mon Sep 17 00:00:00 2001
From: kwanUm
Date: Mon, 4 Nov 2024 13:15:30 +0000
Subject: [PATCH 3/3] nits

---
 ldp/alg/rollout.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py
index 6d29e6ab..61ebdf0a 100644
--- a/ldp/alg/rollout.py
+++ b/ldp/alg/rollout.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import uuid
+from tqdm.asyncio import tqdm_asyncio
 from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager, nullcontext
 from typing import Any, TypeVar, overload
@@ -193,7 +194,6 @@ async def _sample_trajectories_from_envs(
         self.traj_buffer.clear()
 
         traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
-        from tqdm.asyncio import tqdm_asyncio
 
         await tqdm_asyncio.gather(
             *(