@dataclass
class StageTiming:
    """Wall-clock window for a single batch-processing stage.

    Both timestamps are ``time.perf_counter()`` readings. A freshly
    constructed instance has ``start_s == end_s == 0.0``, which is treated
    as "this stage never ran" (see :attr:`is_recorded`).

    Can be used as a context manager to capture the start and end
    timestamps of a pipeline operation::

        with stage_timing:
            do_work()
        # stage_timing.start_s and stage_timing.end_s are now set
    """

    start_s: float = 0.0
    end_s: float = 0.0

    @property
    def is_recorded(self) -> bool:
        """Whether this window holds anything other than the defaults.

        A window counts as unrecorded only when *both* endpoints are still
        ``0.0``; this is the same predicate the batch iterator uses when it
        skips stages during blocked-time attribution.
        """
        return not (self.start_s == 0.0 and self.end_s == 0.0)

    def __enter__(self):
        self.start_s = time.perf_counter()
        return self

    def __exit__(self, *args):
        # Returns None, so an exception raised inside the ``with`` body still
        # propagates after the end timestamp has been captured.
        self.end_s = time.perf_counter()

    @contextmanager
    def timer(self):
        """Alias for using as a context manager, matching Timer.timer() API."""
        # Delegate to __enter__/__exit__ so there is exactly one place that
        # stamps the window.
        with self:
            yield


@dataclass
class BatchTimings:
    """Per-batch pipeline-stage timing windows for overlap-based attribution.

    Each stage field records the ``(start_s, end_s)`` wall-clock window
    during which that pipeline stage was active for this batch. The batch
    iterator later intersects these windows with the window in which the
    consumer was blocked on ``next()``.

    Attributes:
        fetch: Waiting for the upstream block ref plus the ``ray.get()`` call.
        batching: Assembling blocks into a batch via ``_batcher.next_batch()``.
        format: Converting the batch to the requested batch format.
        collate: Running the user-provided ``collate_fn``.
        finalize: Running the user-provided ``finalize_fn``.
        num_rows: Number of rows in this batch.
    """

    fetch: StageTiming = field(default_factory=StageTiming)
    batching: StageTiming = field(default_factory=StageTiming)
    format: StageTiming = field(default_factory=StageTiming)
    collate: StageTiming = field(default_factory=StageTiming)
    finalize: StageTiming = field(default_factory=StageTiming)
    num_rows: int = 0

    def stages(self) -> Iterable[Tuple[str, StageTiming]]:
        """Iterate over ``(name, timing)`` pairs for all pipeline stages."""
        return (
            ("fetch", self.fetch),
            ("batching", self.batching),
            ("format", self.format),
            ("collate", self.collate),
            ("finalize", self.finalize),
        )

    def merge_fetch(self, other: "BatchTimings") -> None:
        """Fold another block's fetch window into this batch's fetch window.

        The result spans from the earliest fetch start to the latest fetch
        end of all merged blocks, so it also covers any gap between
        consecutive block fetches. It is the window in which fetching was
        active, not by itself a measure of consumer blocked time.

        Args:
            other: Timings whose ``fetch`` window should be merged in. An
                unrecorded window (both endpoints ``0.0``) is ignored.
        """
        incoming = other.fetch
        # "Unrecorded" means both endpoints are zero -- checking start_s alone
        # would discard or overwrite a legitimate window that starts at 0.0.
        if incoming.start_s == 0.0 and incoming.end_s == 0.0:
            return

        current = self.fetch
        if not current.is_recorded:
            # First block for this batch: adopt its window as-is.
            current.start_s = incoming.start_s
            current.end_s = incoming.end_s
            return

        # Subsequent blocks: grow the window to the union of both.
        current.start_s = min(current.start_s, incoming.start_s)
        current.end_s = max(current.end_s, incoming.end_s)


@dataclass
class BlockWithTiming:
    """A resolved block paired with its fetch timing window.

    Yielded by :func:`resolve_block_refs` so that downstream pipeline stages
    can see how long each block took to fetch. ``batch_blocks`` wraps
    already-fetched blocks in this type with default (zero) timings.
    """

    # Forward-reference string; resolves to ray.data.block.Block.
    block: "Block"
    timings: BatchTimings = field(default_factory=BatchTimings)
+ timings: Pipeline-stage timing windows for this batch. """ batch_idx: int + timings: BatchTimings = field(default_factory=BatchTimings) @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f9bf0076d2af..dc669c9725fd 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -175,7 +175,7 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] - ) -> Iterator[Block]: + ) -> Iterator[Any]: return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: @@ -249,15 +249,66 @@ def _iter_batches(self) -> Iterator[DataBatch]: while True: with self.get_next_batch_context(): + blocked_start_s = time.perf_counter() try: batch = next(batch_iter) except StopIteration: break + blocked_end_s = time.perf_counter() + self._report_batch_timings(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data self.after_epoch_end() + def _report_batch_timings( + self, batch: Batch, blocked_start_s: float, blocked_end_s: float + ) -> None: + """Attribute per-stage blocked time via overlap with the training window. + + For each pipeline stage we know when it ran ``[stage.start_s, + stage.end_s]`` (recorded by background threads onto + ``batch.metadata.timings``). We also know when the training thread + was blocked ``[blocked_start_s, blocked_end_s]`` (captured in + ``_iter_batches`` around ``next()``). + + The attribution for a stage is the length of the intersection:: + + overlap = min(stage.end, blocked_end) - max(stage.start, blocked_start) + + This correctly handles all prefetch configurations: + + * Stage finished before training blocked → overlap ≤ 0 → zero credit. + * Stage fully inside blocked window → full stage duration credited. 
+ * Partial overlap → partial credit. + + **Invariant**: ``sum(iter_blocked_*) ≤ iter_total_blocked_s``. + + Runs in the training thread; no locks needed because background + threads finished writing ``batch.metadata.timings`` before the batch + was enqueued. + + Args: + batch: The batch whose per-stage timings should be attributed. + blocked_start_s: ``perf_counter()`` value just before the + training thread called ``next(batch_iter)``. + blocked_end_s: ``perf_counter()`` value just after ``next()`` + returned. + """ + if self._stats is None: + return + timings = batch.metadata.timings + for name, stage in timings.stages(): + if stage.start_s == 0.0 and stage.end_s == 0.0: + continue + overlap_s = min(stage.end_s, blocked_end_s) - max( + stage.start_s, blocked_start_s + ) + if overlap_s > 0: + getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) + self._stats.iter_batches_total += 1 + self._stats.iter_rows_total += timings.num_rows + def __iter__(self) -> Iterator[DataBatch]: return self._iter_batches() diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 8a42cde7871e..fcc6cf2a40c2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -22,7 +22,9 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + BlockWithTiming, CollatedBatch, ) from ray.data._internal.stats import DatasetStats @@ -173,31 +175,51 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, -) -> Iterator[Block]: +) -> Iterator[BlockWithTiming]: """Resolves the block references for each logical batch. + Each resolved block is wrapped in a :class:`BlockWithTiming` that carries + the per-block fetch window. 
The fetch window spans from the moment we + start waiting for the upstream iterator (blocked on the data pipeline or + cross-node transfer) until ``ray.get()`` returns the resolved block. + When *stats* is provided, the cumulative fetch time is also recorded in + ``stats.iter_get_s``. + Args: block_ref_iter: An iterator over block object references. - stats: An optional stats object to recording block hits and misses. + stats: An optional stats object to record block hits, misses, and + cumulative fetch time. Yields: - Block: The resolved blocks for each block reference. + BlockWithTiming: Each resolved block with its fetch timing window. """ hits = 0 misses = 0 unknowns = 0 - for block_ref in block_ref_iter: - current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) - hits += current_hit - misses += current_miss - unknowns += current_unknown + while True: + # Time the upstream pull — captures blocked time waiting for the + # data pipeline to produce the next block ref. + timings = BatchTimings() + with timings.fetch: + try: + block_ref = next(block_ref_iter) + except StopIteration: + break - # TODO(amogkam): Optimized further by batching multiple references in a single - # `ray.get()` call. - with stats.iter_get_s.timer() if stats else nullcontext(): - block = ray.get(block_ref) - yield block + current_hit, current_miss, current_unknown = _calculate_ref_hits( + [block_ref] + ) + hits += current_hit + misses += current_miss + unknowns += current_unknown + + # TODO(amogkam): Optimized further by batching multiple references + # in a single `ray.get()` call. 
+ with stats.iter_get_s.timer() if stats else nullcontext(): + block = ray.get(block_ref) + + yield BlockWithTiming(block=block, timings=timings) if stats: stats.iter_blocks_local = hits @@ -206,7 +228,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Block], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -235,7 +257,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Block], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -248,6 +270,7 @@ def __init__( self._drop_last = drop_last self._global_counter = 0 self._done_adding = False + self._pending_timings = BatchTimings() if shuffle_buffer_min_size is not None: self._batcher = ShufflingBatcher( @@ -272,12 +295,20 @@ def __next__(self) -> Batch: if can_yield: with timer: - next_batch = self._batcher.next_batch() + with self._pending_timings.batching: + next_batch = self._batcher.next_batch() res = Batch( - metadata=BatchMetadata(batch_idx=self._global_counter), + metadata=BatchMetadata( + batch_idx=self._global_counter, + timings=self._pending_timings, + ), data=next_batch, ) + res.metadata.timings.num_rows = BlockAccessor.for_block( + next_batch + ).num_rows() + self._pending_timings = BatchTimings() self._global_counter += 1 return res @@ -286,8 +317,9 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - self._batcher.add(block) + block_with_timing = next(self._block_iter) + self._pending_timings.merge_fetch(block_with_timing.timings) + self._batcher.add(block_with_timing.block) except StopIteration: self._batcher.done_adding() self._done_adding = True @@ -306,12 +338,13 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - with 
stats.iter_format_batch_s.timer() if stats else nullcontext(): - formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) - if ensure_copy: - formatted_data = _copy_batch(formatted_data) + with batch.metadata.timings.format: + with stats.iter_format_batch_s.timer() if stats else nullcontext(): + formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) + if ensure_copy: + formatted_data = _copy_batch(formatted_data) return dataclasses.replace(batch, data=formatted_data) @@ -359,8 +392,9 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): - collated_data = collate_fn(batch.data) + with batch.metadata.timings.collate: + with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + collated_data = collate_fn(batch.data) return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -384,8 +418,9 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): - finalized_data = finalize_fn(batch.data) + with batch.metadata.timings.finalize: + with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): + finalized_data = finalize_fn(batch.data) return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 618ed8d6a376..dc5c32d1d500 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1,6 +1,7 @@ import collections import copy import logging +import re import time from collections import defaultdict from contextlib import contextmanager @@ -60,6 +61,19 @@ StatsDict = Dict[str, List[BlockStats]] +def _create_iteration_tags(dataset_tag: Optional[str]): + dataset_tag = dataset_tag or "unknown_dataset" + tags = 
{"dataset": dataset_tag, "rank": "unknown"} + # Use findall + last match: the streaming-split index is always the + # trailing ``split_`` in the tag. The user-defined dataset name may + # itself contain ``split_`` so re.search (first match) could + # pick up the wrong one. + matches = re.findall(r"split_(\d+)", dataset_tag) + if matches: + tags["rank"] = matches[-1] + return tags + + def fmt(seconds: float) -> str: if seconds > 1: return str(round(seconds, 2)) + "s" @@ -181,17 +195,22 @@ def __init__(self): self._min: float = float("inf") self._max: float = 0 self._total_count: float = 0 + # Wall-clock window of the most recent timer() invocation. + # Used by overlap-based blocked attribution in iter_batches. + self.start_s: float = 0.0 + self.end_s: float = 0.0 # Bounded-memory percentile backend. add() forwards every value # to ``add_sample`` and ``percentile`` reads from it. self._distribution: DistributionTracker = DistributionTracker() @contextmanager def timer(self) -> None: - time_start = time.perf_counter() + self.start_s = time.perf_counter() try: yield finally: - self.add(time.perf_counter() - time_start) + self.end_s = time.perf_counter() + self.add(self.end_s - self.start_s) def add(self, value: float) -> None: self._total += value @@ -448,7 +467,7 @@ def __init__(self, max_stats=1000): # Per Node metrics self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() - iter_tag_keys = ("dataset",) + iter_tag_keys = ("dataset", "rank") self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -488,6 +507,46 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked by iter_batches()", tag_keys=iter_tag_keys, ) + self.iter_total_s = Gauge( + "data_iter_total_seconds", + description="Total wall-clock seconds spent in the dataset iterator", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_fetch_s = Gauge( + "data_iter_blocked_fetch_seconds", + description="Seconds user thread is blocked on block 
fetching", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_batching_s = Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_format_s = Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_collate_s = Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_finalize_s = Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ) + self.iter_batches_total = Gauge( + "data_iter_batches_total", + description="Total batches delivered to the user thread", + tag_keys=iter_tag_keys, + ) + self.iter_rows_total = Gauge( + "data_iter_rows_total", + description="Total rows delivered to the user thread", + tag_keys=iter_tag_keys, + ) self.iter_user_s = Gauge( "data_iter_user_seconds", description="Seconds spent in user code", @@ -725,9 +784,10 @@ def update_iteration_metrics( stats: "DatasetStats", dataset_tag, ): - tags = self._create_tags(dataset_tag) + tags = self._create_iteration_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) + self.iter_total_s.set(stats.iter_total_s.get(), tags) self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags) self.iter_get_s.set(stats.iter_get_s.get(), tags) self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags) @@ -748,6 +808,13 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) + self.iter_blocked_fetch_s.set(stats.iter_blocked_fetch_s.get(), tags) + self.iter_blocked_batching_s.set(stats.iter_blocked_batching_s.get(), tags) + 
self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) + self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) + self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) + self.iter_batches_total.set(stats.iter_batches_total, tags) + self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) def register_dataset( @@ -941,6 +1008,9 @@ def _create_tags( tags["node_ip"] = node_ip_tag return tags + def _create_iteration_tags(self, dataset_tag: Optional[str]): + return _create_iteration_tags(dataset_tag) + def get_or_create_stats_actor() -> ActorHandle[_StatsActor]: """Each cluster will contain exactly 1 _StatsActor. This function @@ -1138,9 +1208,16 @@ def __init__( self.iter_finalize_batch_s: Timer = Timer() self.iter_time_to_first_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() + self.iter_blocked_fetch_s: Timer = Timer() + self.iter_blocked_batching_s: Timer = Timer() + self.iter_blocked_format_s: Timer = Timer() + self.iter_blocked_collate_s: Timer = Timer() + self.iter_blocked_finalize_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() + self.iter_batches_total: int = 0 + self.iter_rows_total: int = 0 self.extra_metrics = {} # Block fetch stats during iteration. @@ -1196,6 +1273,13 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocks_remote, self.iter_unknown_location, self.iter_prefetched_bytes, + self.iter_blocked_fetch_s, + self.iter_blocked_batching_s, + self.iter_blocked_format_s, + self.iter_blocked_collate_s, + self.iter_blocked_finalize_s, + self.iter_batches_total, + self.iter_rows_total, ) stats_summary_parents = [] @@ -1878,6 +1962,15 @@ class IterStatsSummary: iter_unknown_location: int # Current bytes of prefetched blocks in the iterator iter_prefetched_bytes: int + # Per-stage training-thread blocked attribution timers. 
+ blocked_fetch_time: Timer + blocked_batching_time: Timer + blocked_format_time: Timer + blocked_collate_time: Timer + blocked_finalize_time: Timer + # Cumulative batch and row counters. + batches_total: int + rows_total: int def __str__(self) -> str: return self.to_string() @@ -1984,6 +2077,24 @@ def to_string(self) -> str: out += "Streaming split coordinator overhead time: " out += f"{fmt(self.streaming_split_coord_time.get())}\n" + # Per-stage training-thread blocked attribution. + stage_totals = [ + ("block fetch (ray.get)", self.blocked_fetch_time), + ("batching", self.blocked_batching_time), + ("format", self.blocked_format_time), + ("collate", self.blocked_collate_time), + ("finalize (host->device)", self.blocked_finalize_time), + ] + active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] + if active_stages: + out += "\nPer-stage training-thread blocked time breakdown:\n" + for stage_name, timer in active_stages: + out += " * {}: {}\n".format(stage_name, fmt(timer.get())) + if self.batches_total: + out += "Total batches consumed: {}\n".format(self.batches_total) + if self.rows_total: + out += "Total rows consumed: {}\n".format(self.rows_total) + return out def __repr__(self, level=0) -> str: diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 1f69962435ad..9cfaea2f8dfc 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -22,7 +22,7 @@ from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.logical.interfaces import LogicalPlan from ray.data._internal.logical.operators import InputData -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, _StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.collate_fn import ( ArrowBatchCollateFn, @@ -297,6 +297,7 @@ def callback(num_bytes: int) -> None: yield from batch_iterator if stats: stats.iter_total_s.add(time.perf_counter() - 
time_start) + _StatsManager.update_iteration_metrics(stats, dataset_tag) finally: # On early exit (e.g. ``break`` in the for-loop), the inner # ``_ClosingIterator`` would only shut down the executor via diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b0b55d1cb8fb..63128be388e3 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -11,7 +11,9 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + StageTiming, ) from ray.data._internal.block_batching.iter_batches import ( BatchIterator, @@ -114,6 +116,382 @@ def test_restore_from_original_order(): assert idx == [0, 1, 2, 3] +def test_restore_original_order_stats(): + base_iterator = [ + Batch(BatchMetadata(batch_idx=2), None), + Batch(BatchMetadata(batch_idx=0), None), + Batch(BatchMetadata(batch_idx=1), None), + ] + + ordered = list(restore_original_order(iter(base_iterator))) + + assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] + + +def test_report_batch_timings_overlap_attribution(): + stats = DatasetStats(metadata={}, parent=None) + batch_iterator = BatchIterator(iter([]), stats=stats) + timings = BatchTimings(num_rows=8) + timings.fetch = StageTiming(start_s=10.0, end_s=20.0) + timings.batching = StageTiming(start_s=20.0, end_s=30.0) + timings.format = StageTiming(start_s=30.0, end_s=40.0) + timings.finalize = StageTiming(start_s=50.0, end_s=60.0) + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + batch_iterator._report_batch_timings( + batch, blocked_start_s=15.0, blocked_end_s=35.0 + ) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(10.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_collate_s.get() == 0 + assert 
stats.iter_blocked_finalize_s.get() == 0 + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 8 + + +def _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.0, + batching_start=0.0, + batching_end=0.0, + format_start=0.0, + format_end=0.0, + collate_start=0.0, + collate_end=0.0, + finalize_start=0.0, + finalize_end=0.0, + num_rows=0, +): + """Helper to construct a Batch with specific stage timing windows.""" + timings = BatchTimings(num_rows=num_rows) + timings.fetch = StageTiming(start_s=fetch_start, end_s=fetch_end) + timings.batching = StageTiming(start_s=batching_start, end_s=batching_end) + timings.format = StageTiming(start_s=format_start, end_s=format_end) + timings.collate = StageTiming(start_s=collate_start, end_s=collate_end) + timings.finalize = StageTiming(start_s=finalize_start, end_s=finalize_end) + return Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + +def _make_report_iterator(stats): + """Create a BatchIterator wired to the given stats without a real pipeline.""" + it = BatchIterator.__new__(BatchIterator) + it._stats = stats + return it + + +class TestReportBatchTimingsEdgeCases: + """Edge case tests for overlap-based blocked attribution.""" + + def test_zero_overlap_stage_finished_before_blocked(self): + """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.5) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == 0.0 + + def test_zero_overlap_blocked_before_stage(self): + """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=1.0) + assert 
stats.iter_blocked_format_s.get() == 0.0 + + def test_partial_overlap(self): + """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=2.0) + it._report_batch_timings(batch, blocked_start_s=1.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + + def test_full_overlap_stage_inside_blocked(self): + """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + + def test_no_collate_fn_zero_attribution(self): + """collate stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == 0.0 + + def test_no_finalize_fn_zero_attribution(self): + """finalize stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == 0.0 + + def test_prefetch_hides_fetch_from_training(self): + """Effective prefetch: fetch done before training blocks → 0 fetch attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = 
_make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.5, + collate_start=2.3, + collate_end=2.6, + ) + # Training only starts blocking at t=2 (prefetch worked) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=2.6) + assert stats.iter_blocked_fetch_s.get() == 0.0 + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) + + def test_accumulation_across_batches(self): + """Two batches each contribute to fetch — values accumulate.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 + b1 = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.0, num_rows=10) + it._report_batch_timings(b1, blocked_start_s=0.0, blocked_end_s=2.0) + # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 + b2 = _make_batch_with_timings(fetch_start=5.0, fetch_end=6.0, num_rows=20) + it._report_batch_timings(b2, blocked_start_s=5.0, blocked_end_s=7.0) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 30 + + def test_overlap_invariant_sum_leq_total(self): + """sum(iter_blocked_*) <= iter_total_blocked_s always holds.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + num_rows=5, + ) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + total = stats.iter_total_blocked_s.get() + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + ) + assert sum_stages <= total + 1e-9 + + def test_blocked_inside_stage(self): + """Stage [0, 10] fully contains blocked 
[3, 5] → overlap = 2.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=10.0) + it._report_batch_timings(batch, blocked_start_s=3.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + + def test_all_stages_simultaneous_overlap(self): + """Multiple stages overlap with blocked window simultaneously.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + collate_start=3.0, + collate_end=4.0, + finalize_start=4.0, + finalize_end=5.0, + num_rows=100, + ) + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(1.0) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 100 + + +class TestStageTimingRecord: + """Tests for StageTiming.record() behavior.""" + + def test_context_manager_captures_window(self): + """Using as context manager captures start_s and end_s.""" + t = StageTiming() + with t: + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s + + def test_timer_context_manager(self): + """The timer() method works as a context manager too.""" + t = StageTiming() + with t.timer(): + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s + + def test_default_values(self): + """Unrecorded StageTiming has start_s=0 and end_s=0.""" + t = StageTiming() + assert t.start_s == 0.0 + assert t.end_s == 0.0 + + +class TestMergeFetch: + """Tests for 
BatchTimings.merge_fetch() with multiple blocks per batch.""" + + def test_merge_single_block(self): + """Merging a single block preserves its fetch window.""" + dst = BatchTimings() + src = BatchTimings() + src.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src) + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 2.0 + + def test_merge_multiple_blocks_expands_window(self): + """Merging multiple blocks produces the union window.""" + dst = BatchTimings() + + # Block 1: fetched [1.0, 2.0] + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src1) + + # Block 2: fetched [3.0, 4.0] + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=4.0) + dst.merge_fetch(src2) + + # Block 3: fetched [5.0, 6.0] + src3 = BatchTimings() + src3.fetch = StageTiming(start_s=5.0, end_s=6.0) + dst.merge_fetch(src3) + + # Union: [1.0, 6.0] + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 6.0 + + def test_merge_unrecorded_block_ignored(self): + """Merging a block with no fetch timing (start_s=0) is a no-op.""" + dst = BatchTimings() + dst.fetch = StageTiming(start_s=2.0, end_s=3.0) + + src = BatchTimings() # fetch defaults to (0.0, 0.0) + dst.merge_fetch(src) + + assert dst.fetch.start_s == 2.0 + assert dst.fetch.end_s == 3.0 + + def test_merge_overlapping_blocks(self): + """Overlapping fetch windows are correctly merged.""" + dst = BatchTimings() + + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + dst.merge_fetch(src1) + + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + dst.merge_fetch(src2) + + # Union: [1.0, 7.0] + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 7.0 + + def test_merge_into_empty_destination(self): + """Merging into an empty BatchTimings takes the source window.""" + dst = BatchTimings() # fetch = (0.0, 0.0) + src = BatchTimings() + src.fetch = StageTiming(start_s=10.0, end_s=20.0) + dst.merge_fetch(src) + assert 
dst.fetch.start_s == 10.0 + assert dst.fetch.end_s == 20.0 + + +class TestEndToEndTimingPropagation: + """Tests that stage timings propagate correctly through the full pipeline.""" + + def test_batch_carries_timings_through_pipeline(self): + """A Batch's metadata.timings carries all stage windows.""" + timings = BatchTimings(num_rows=50) + timings.fetch = StageTiming(start_s=1.0, end_s=2.0) + timings.batching = StageTiming(start_s=2.0, end_s=3.0) + timings.format = StageTiming(start_s=3.0, end_s=4.0) + timings.collate = StageTiming(start_s=4.0, end_s=5.0) + timings.finalize = StageTiming(start_s=5.0, end_s=6.0) + + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + # Verify all stages are accessible via stages() iterator + stage_dict = dict(batch.metadata.timings.stages()) + assert len(stage_dict) == 5 + assert stage_dict["fetch"].start_s == 1.0 + assert stage_dict["batching"].end_s == 3.0 + assert stage_dict["format"].start_s == 3.0 + assert stage_dict["collate"].end_s == 5.0 + assert stage_dict["finalize"].start_s == 5.0 + assert batch.metadata.timings.num_rows == 50 + + def test_full_pipeline_attribution(self): + """End-to-end: all 5 stages with realistic timing, full overlap.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.5, + batching_start=0.5, + batching_end=1.0, + format_start=1.0, + format_end=2.0, + collate_start=2.0, + collate_end=2.5, + finalize_start=2.5, + finalize_end=3.0, + num_rows=256, + ) + + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + # Each stage gets its full duration + assert stats.iter_blocked_fetch_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_batching_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == 
pytest.approx(0.5) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 256 + + # Invariant: sum = 3.0 <= total_blocked = 5.0 + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + ) + assert sum_stages == pytest.approx(3.0) + assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 + + def test_finalize_fn_uses_single_thread(ray_start_regular_shared): """Tests that finalize_fn is not run with multiple threads.""" ref_bundles_iter = ref_bundle_generator(num_blocks=20, num_rows=2) @@ -194,6 +572,27 @@ def collate_fn(batch: pd.DataFrame): assert concat_df["foo"].iloc[i + 1] >= concat_df["foo"].iloc[i] +def test_iter_batches_counts_rows_at_pipeline_exit(ray_start_regular_shared): + stats = DatasetStats(metadata={}, parent=None) + ref_bundles_iter = ref_bundle_generator(num_blocks=4, num_rows=2) + + output_batches = list( + BatchIterator( + ref_bundles_iter, + stats=stats, + batch_size=3, + prefetch_batches=0, + batch_format="pandas", + drop_last=True, + ) + ) + + assert len(output_batches) == 2 + assert [len(batch) for batch in output_batches] == [3, 3] + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 6 + + def test_iter_batches_e2e_async(ray_start_regular_shared): """We add time.sleep in 3 places: 1. In the base generator to simulate streaming executor blocking on next results. 
diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 6ead5741f0e1..7a6223b06f78 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -13,7 +13,12 @@ import pytest import ray -from ray.data._internal.block_batching.interfaces import Batch, BatchMetadata +from ray.data._internal.block_batching.interfaces import ( + Batch, + BatchMetadata, + BatchTimings, + BlockWithTiming, +) from ray.data._internal.block_batching.util import ( _calculate_ref_hits, blocks_to_batches, @@ -37,7 +42,9 @@ def test_resolve_block_refs(ray_start_regular_shared): block_refs = [ray.put(0), ray.put(1), ray.put(2)] resolved_iter = resolve_block_refs(iter(block_refs)) - assert list(resolved_iter) == [0, 1, 2] + resolved = list(resolved_iter) + assert all(isinstance(b, BlockWithTiming) for b in resolved) + assert [b.block for b in resolved] == [0, 1, 2] @pytest.mark.parametrize("block_size", [1, 10]) @@ -45,10 +52,14 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_blocks_to_batches(block_size, drop_last): num_blocks = 5 block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + # Wrap raw blocks in BlockWithTiming as blocks_to_batches now expects + wrapped_blocks = ( + BlockWithTiming(block=b, timings=BatchTimings()) for b in block_iter + ) batch_size = 3 batch_iter = list( - blocks_to_batches(block_iter, batch_size=batch_size, drop_last=drop_last) + blocks_to_batches(wrapped_blocks, batch_size=batch_size, drop_last=drop_last) ) if drop_last: diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index cb4c31553541..245c0bfed8d4 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -36,6 +36,7 @@ OperatorStatsSummary, StatsSummary, Timer, + _create_iteration_tags, _StatsActor, get_or_create_stats_actor, ) @@ -1878,6 +1879,160 @@ def 
test_stats_actor_iter_metrics(): assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1] +def test_create_iteration_tags_extracts_rank(): + assert _create_iteration_tags("train_abc_split_2") == { + "dataset": "train_abc_split_2", + "rank": "2", + } + assert _create_iteration_tags("dataset_without_split") == { + "dataset": "dataset_without_split", + "rank": "unknown", + } + # User-defined dataset name may contain split_; the trailing + # split index (from streaming split coordinator) should be used. + assert _create_iteration_tags("my_split_3_data_abc123_split_5") == { + "dataset": "my_split_3_data_abc123_split_5", + "rank": "5", + } + + +def test_update_iteration_metrics_exports_new_iter_metrics(): + stats = DatasetStats(metadata={}, parent=None) + stats.iter_total_s.add(11.0) + stats.iter_blocked_fetch_s.add(1.0) + stats.iter_blocked_batching_s.add(2.0) + stats.iter_blocked_format_s.add(3.0) + stats.iter_blocked_collate_s.add(4.0) + stats.iter_blocked_finalize_s.add(5.0) + stats.iter_batches_total = 7 + stats.iter_rows_total = 8 + + actor = _StatsActor.__ray_metadata__.modified_class.__new__( + _StatsActor.__ray_metadata__.modified_class + ) + recorded = {} + + class FakeGauge: + def __init__(self, name): + self.name = name + + def set(self, value, tags): + recorded[self.name] = (value, tags) + + for attr in [ + "iter_initialize_s", + "iter_total_s", + "iter_get_ref_bundles_s", + "iter_get_s", + "iter_next_batch_s", + "iter_format_batch_s", + "iter_collate_batch_s", + "iter_finalize_batch_s", + "iter_blocks_local", + "iter_blocks_remote", + "iter_unknown_location", + "iter_prefetched_bytes", + "iter_block_fetching_s", + "iter_batch_shaping_s", + "iter_batch_formatting_s", + "iter_batch_collating_s", + "iter_batch_finalizing_s", + "time_to_first_batch_s", + "iter_total_blocked_s", + "iter_blocked_fetch_s", + "iter_blocked_batching_s", + "iter_blocked_format_s", + "iter_blocked_collate_s", + "iter_blocked_finalize_s", + "iter_batches_total", + 
"iter_rows_total", + "iter_user_s", + ]: + setattr(actor, attr, FakeGauge(attr)) + + actor.update_iteration_metrics(stats, "train_dataset_split_3") + + expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} + assert recorded["iter_total_s"] == (11.0, expected_tags) + assert recorded["iter_blocked_fetch_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_batching_s"] == (2.0, expected_tags) + assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) + assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) + assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) + assert recorded["iter_batches_total"] == (7, expected_tags) + assert recorded["iter_rows_total"] == (8, expected_tags) + + +def test_iter_stats_summary_has_new_fields(): + """IterStatsSummary includes per-stage blocked timers and counters.""" + stats = DatasetStats(metadata={}, parent=None) + summary = stats.to_summary() + iter_summary = summary.iter_stats + + assert hasattr(iter_summary, "blocked_fetch_time") + assert hasattr(iter_summary, "blocked_batching_time") + assert hasattr(iter_summary, "blocked_format_time") + assert hasattr(iter_summary, "blocked_collate_time") + assert hasattr(iter_summary, "blocked_finalize_time") + assert hasattr(iter_summary, "batches_total") + assert hasattr(iter_summary, "rows_total") + + +def test_iter_stats_summary_reflects_accumulated_values(): + """IterStatsSummary carries the accumulated timer values.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_blocked_batching_s.add(0.2) + stats.iter_batches_total = 10 + stats.iter_rows_total = 320 + + summary = stats.to_summary().iter_stats + assert summary.blocked_fetch_time.get() == pytest.approx(0.5) + assert summary.blocked_batching_time.get() == pytest.approx(0.2) + assert summary.batches_total == 10 + assert summary.rows_total == 320 + + +def test_iter_stats_to_string_shows_stage_breakdown(): + """to_string() renders 
per-stage breakdown when values are non-zero.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(1.5) + stats.iter_blocked_format_s.add(0.8) + stats.iter_batches_total = 5 + stats.iter_rows_total = 160 + stats.iter_total_blocked_s.add(2.3) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + assert "format" in text + assert "Total batches consumed: 5" in text + assert "Total rows consumed: 160" in text + assert "Per-stage training-thread blocked time breakdown" in text + + +def test_iter_stats_to_string_omits_zero_stages(): + """to_string() omits stages with zero values from the breakdown.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_total_blocked_s.add(0.5) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + # Zero stages should not appear + assert "batching" not in text + assert "collate" not in text + assert "restore order" not in text + + +def test_iter_stats_to_string_no_breakdown_when_all_zero(): + """When all blocked_* stages are zero, no breakdown section appears.""" + stats = DatasetStats(metadata={}, parent=None) + text = str(stats.to_summary().iter_stats) + assert "Per-stage training-thread blocked time breakdown" not in text + assert "Total batches consumed" not in text + assert "Total rows consumed" not in text + + def test_dataset_name_and_id(): # Test deprecated APIs: _set_name and _name ds = ray.data.range(1)