From 071ebd6507503f480618b269abd182577e670f48 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Wed, 17 Jun 2026 23:57:00 +0800 Subject: [PATCH 01/12] [Data] Add per-stage training-thread blocking attribution and pipeline observability to iter_batches Implements overlap-based latency attribution for Ray Data's iter_batches pipeline, addressing #64132 and RFC #63911. Each pipeline stage (fetch, batching, format, collate, finalize, restore_order) records an independent (start_s, end_s) time window. The training thread captures its own blocked window around next(). Attribution per stage is the overlap of the two windows, correctly handling prefetch > 1. New Prometheus metrics (14 total): - data_iter_blocked_{fetch,batching,format,collate,finalize,restore_order}_seconds - data_iter_batches_total, data_iter_rows_total - data_iter_total_seconds, data_iter_restore_order_buffer_peak - data_iter_shuffle_buffer_{rows,compactions_total,compaction_seconds} - data_iter_prefetch_queue_depth Also adds: - Per-stage breakdown rendering in IterStatsSummary.to_string() - Rank extraction from dataset tags for Prometheus labels - Final metrics flush on iterator completion Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/batcher.py | 9 + .../_internal/block_batching/interfaces.py | 58 ++- .../_internal/block_batching/iter_batches.py | 64 ++- .../ray/data/_internal/block_batching/util.py | 53 ++- python/ray/data/_internal/stats.py | 122 ++++- python/ray/data/iterator.py | 3 +- .../tests/block_batching/test_iter_batches.py | 432 ++++++++++++++++++ python/ray/data/tests/test_stats.py | 159 +++++++ 8 files changed, 878 insertions(+), 22 deletions(-) diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index c097ee668de6..7ff6ccf42c11 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -1,3 +1,4 @@ +import time import warnings from typing import Optional @@ -235,6 +236,8 @@ def __init__( self._total_object_store_nbytes = get_total_obj_store_mem_on_node() self._total_num_rows_added = 0 self._total_nbytes_added = 0 + self.compactions_total = 0 + self.compaction_time_s = 0.0 def add(self, block: Block): """Add a block to the shuffle buffer. @@ -320,6 +323,9 @@ def _num_rows(self) -> int: """ return self._num_compacted_rows() + self._num_uncompacted_rows() + def num_rows(self) -> int: + return self._num_rows() + def _num_compacted_rows(self) -> int: """Return number of unyielded rows in the compacted buffer.""" if self._shuffle_buffer is None: @@ -341,6 +347,7 @@ def next_batch(self) -> Block: self._done_adding or self._num_compacted_rows() <= self._min_rows_to_yield_batch ): + compaction_start_s = time.perf_counter() if self._shuffle_buffer is not None and self._batch_head < len( self._shuffled_indices ): @@ -363,6 +370,8 @@ def next_batch(self) -> Block: self._builder = DelegatingBlockBuilder() self._batch_head = 0 + self.compactions_total += 1 + self.compaction_time_s += time.perf_counter() - compaction_start_s assert self._shuffle_buffer is not None assert self._shuffled_indices is not None diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 4f0bed6b3dd4..10587586c941 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,11 +1,63 @@ import abc -from dataclasses import dataclass -from typing import Any, List +from dataclasses import dataclass, field +from typing import Any, Iterable, List, Tuple from ray.data.block import Block, DataBatch from ray.types import ObjectRef +@dataclass +class StageTiming: + """Wall-clock window for a batch-processing stage.""" + + start_s: float = 0.0 + end_s: float = 0.0 + + def record(self, start_s: float, end_s: float) -> None: + if self.start_s == 0.0: + self.start_s = start_s + self.end_s = end_s + + +@dataclass +class BatchTimings: + fetch: StageTiming = field(default_factory=StageTiming) + batching: StageTiming = field(default_factory=StageTiming) + format: StageTiming = field(default_factory=StageTiming) + collate: StageTiming = field(default_factory=StageTiming) + finalize: StageTiming = field(default_factory=StageTiming) + restore_order: StageTiming = field(default_factory=StageTiming) + num_rows: int = 0 + + def stages(self) -> Iterable[Tuple[str, StageTiming]]: + return ( + ("fetch", self.fetch), + ("batching", self.batching), + ("format", self.format), + ("collate", self.collate), + ("finalize", self.finalize), + ("restore_order", self.restore_order), + ) + + def merge_fetch(self, other: "BatchTimings") -> None: + self._merge_stage(self.fetch, other.fetch) + + @staticmethod + def _merge_stage(dst: StageTiming, src: StageTiming) -> None: + if src.start_s == 0.0: + return + if dst.start_s == 0.0 or src.start_s < dst.start_s: + dst.start_s = src.start_s + if src.end_s > dst.end_s: + dst.end_s = src.end_s + + +@dataclass +class BlockWithTiming: + block: Block + timings: BatchTimings = field(default_factory=BatchTimings) + + @dataclass class BatchMetadata: """Metadata associated with a batch. @@ -13,9 +65,11 @@ class BatchMetadata: Attributes: batch_idx: The global index of this batch so that downstream operations can maintain ordering. + timings: Pipeline-stage timing windows for this batch. """ batch_idx: int + timings: BatchTimings = field(default_factory=BatchTimings) @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f9bf0076d2af..66e55437d68f 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -175,8 +175,10 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] - ) -> Iterator[Block]: - return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) + ) -> Iterator[Any]: + return resolve_block_refs( + block_ref_iter=block_refs, stats=self._stats, record_timings=True + ) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: return blocks_to_batches( @@ -216,7 +218,7 @@ def _finalize_batches( def _restore_original_batch_order( self, batches: Iterator[Batch] ) -> Iterator[Batch]: - return restore_original_order(batches) + return restore_original_order(batches, stats=self._stats) def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: # Step 1: Prefetch logical batches locally. @@ -248,16 +250,36 @@ def _iter_batches(self) -> Iterator[DataBatch]: self.before_epoch_start() while True: + blocked_start_s = time.perf_counter() with self.get_next_batch_context(): try: batch = next(batch_iter) except StopIteration: break + blocked_end_s = time.perf_counter() + self._report_batch_timings(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data self.after_epoch_end() + def _report_batch_timings( + self, batch: Batch, blocked_start_s: float, blocked_end_s: float + ) -> None: + if self._stats is None: + return + timings = batch.metadata.timings + for name, stage in timings.stages(): + if stage.start_s == 0.0 and stage.end_s == 0.0: + continue + overlap_s = min(stage.end_s, blocked_end_s) - max( + stage.start_s, blocked_start_s + ) + if overlap_s > 0: + getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) + self._stats.iter_batches_total += 1 + self._stats.iter_rows_total += timings.num_rows + def __iter__(self) -> Iterator[DataBatch]: return self._iter_batches() @@ -452,7 +474,9 @@ def get_next_ref_bundle() -> RefBundle: prefetcher.stop() -def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: +def restore_original_order( + batch_iter: Iterator[Batch], stats: Optional[DatasetStats] = None +) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on @@ -463,13 +487,31 @@ def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """ next_index_required = 0 buffer: Dict[int, Batch] = {} - for batch in batch_iter: - assert batch.metadata.batch_idx not in buffer - buffer[batch.metadata.batch_idx] = batch + restore_wait_start_s: Optional[float] = None + source_exhausted = False + + while True: while next_index_required in buffer: - yield buffer.pop(next_index_required) + next_batch = buffer.pop(next_index_required) + if restore_wait_start_s is not None: + next_batch.metadata.timings.restore_order.record( + restore_wait_start_s, time.perf_counter() + ) + restore_wait_start_s = None + yield next_batch next_index_required += 1 - while next_index_required in buffer: - yield buffer.pop(next_index_required) - next_index_required += 1 + if source_exhausted: + break + + if buffer and restore_wait_start_s is None: + restore_wait_start_s = time.perf_counter() + + try: + batch = next(batch_iter) + except StopIteration: + source_exhausted = True + continue + + assert batch.metadata.batch_idx not in buffer + buffer[batch.metadata.batch_idx] = batch diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 8a42cde7871e..fb9f1f3aa11a 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,6 +3,7 @@ import logging import queue import threading +import time from contextlib import nullcontext from typing import ( Any, @@ -14,6 +15,7 @@ Optional, Tuple, TypeVar, + Union, ) import ray @@ -22,8 +24,11 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + BlockWithTiming, CollatedBatch, + StageTiming, ) from ray.data._internal.stats import DatasetStats from ray.data.block import Block, BlockAccessor, DataBatch @@ -170,18 +175,27 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: return 0, 0, 0 +def _record_stage_window(stage: StageTiming, start_s: float, end_s: float) -> None: + stage.record(start_s, end_s) + + def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, -) -> Iterator[Block]: + record_timings: bool = False, +) -> Iterator[Union[Block, BlockWithTiming]]: """Resolves the block references for each logical batch. Args: block_ref_iter: An iterator over block object references. stats: An optional stats object to recording block hits and misses. + record_timings: If True, wrap each resolved block in a + ``BlockWithTiming`` carrying the per-block fetch window. Yields: - Block: The resolved blocks for each block reference. + Union[Block, BlockWithTiming]: The resolved blocks. When + *record_timings* is ``True`` each block is wrapped in a + ``BlockWithTiming``; otherwise raw ``Block`` instances are yielded. """ hits = 0 misses = 0 @@ -195,9 +209,16 @@ def resolve_block_refs( # TODO(amogkam): Optimized further by batching multiple references in a single # `ray.get()` call. + start_s = time.perf_counter() with stats.iter_get_s.timer() if stats else nullcontext(): block = ray.get(block_ref) - yield block + end_s = time.perf_counter() + if record_timings: + timings = BatchTimings() + _record_stage_window(timings.fetch, start_s, end_s) + yield BlockWithTiming(block=block, timings=timings) + else: + yield block if stats: stats.iter_blocks_local = hits @@ -206,7 +227,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Block], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -235,7 +256,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Block], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -248,6 +269,7 @@ def __init__( self._drop_last = drop_last self._global_counter = 0 self._done_adding = False + self._pending_timings = BatchTimings() if shuffle_buffer_min_size is not None: self._batcher = ShufflingBatcher( @@ -272,12 +294,22 @@ def __next__(self) -> Batch: if can_yield: with timer: + start_s = time.perf_counter() next_batch = self._batcher.next_batch() + end_s = time.perf_counter() res = Batch( - metadata=BatchMetadata(batch_idx=self._global_counter), + metadata=BatchMetadata( + batch_idx=self._global_counter, + timings=self._pending_timings, + ), data=next_batch, ) + _record_stage_window(res.metadata.timings.batching, start_s, end_s) + res.metadata.timings.num_rows = BlockAccessor.for_block( + next_batch + ).num_rows() + self._pending_timings = BatchTimings() self._global_counter += 1 return res @@ -287,6 +319,9 @@ def __next__(self) -> Batch: try: # NOTE: Block ref is released immediately block = next(self._block_iter) + if isinstance(block, BlockWithTiming): + self._pending_timings.merge_fetch(block.timings) + block = block.block self._batcher.add(block) except StopIteration: self._batcher.done_adding() @@ -306,12 +341,14 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: + start_s = time.perf_counter() with stats.iter_format_batch_s.timer() if stats else nullcontext(): formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) if ensure_copy: formatted_data = _copy_batch(formatted_data) + _record_stage_window(batch.metadata.timings.format, start_s, time.perf_counter()) return dataclasses.replace(batch, data=formatted_data) @@ -359,8 +396,10 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: + start_s = time.perf_counter() with stats.iter_collate_batch_s.timer() if stats else nullcontext(): collated_data = collate_fn(batch.data) + _record_stage_window(batch.metadata.timings.collate, start_s, time.perf_counter()) return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -384,8 +423,10 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: + start_s = time.perf_counter() with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): finalized_data = finalize_fn(batch.data) + _record_stage_window(batch.metadata.timings.finalize, start_s, time.perf_counter()) return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 618ed8d6a376..7afdb967aeaa 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1,6 +1,7 @@ import collections import copy import logging +import re import time from collections import defaultdict from contextlib import contextmanager @@ -60,6 +61,19 @@ StatsDict = Dict[str, List[BlockStats]] +def _create_iteration_tags(dataset_tag: Optional[str]): + dataset_tag = dataset_tag or "unknown_dataset" + tags = {"dataset": dataset_tag, "rank": "unknown"} + # Use findall + last match: the streaming-split index is always the + # trailing ``split_`` in the tag. The user-defined dataset name may + # itself contain ``split_`` so re.search (first match) could + # pick up the wrong one. + matches = re.findall(r"split_(\d+)", dataset_tag) + if matches: + tags["rank"] = matches[-1] + return tags + + def fmt(seconds: float) -> str: if seconds > 1: return str(round(seconds, 2)) + "s" @@ -448,7 +462,7 @@ def __init__(self, max_stats=1000): # Per Node metrics self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() - iter_tag_keys = ("dataset",) + iter_tag_keys = ("dataset", "rank") self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -488,6 +502,51 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked by iter_batches()", tag_keys=iter_tag_keys, ) + self.iter_total_s = Gauge( + "data_iter_total_seconds", + description="Total wall-clock seconds spent in the dataset iterator", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_fetch_s = Gauge( + "data_iter_blocked_fetch_seconds", + description="Seconds user thread is blocked on block fetching", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_batching_s = Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_format_s = Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_collate_s = Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_finalize_s = Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_restore_order_s = Gauge( + "data_iter_blocked_restore_order_seconds", + description="Seconds user thread is blocked on restoring batch order", + tag_keys=iter_tag_keys, + ) + self.iter_batches_total = Gauge( + "data_iter_batches_total", + description="Total batches delivered to the user thread", + tag_keys=iter_tag_keys, + ) + self.iter_rows_total = Gauge( + "data_iter_rows_total", + description="Total rows delivered to the user thread", + tag_keys=iter_tag_keys, + ) self.iter_user_s = Gauge( "data_iter_user_seconds", description="Seconds spent in user code", @@ -725,9 +784,10 @@ def update_iteration_metrics( stats: "DatasetStats", dataset_tag, ): - tags = self._create_tags(dataset_tag) + tags = self._create_iteration_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) + self.iter_total_s.set(stats.iter_total_s.get(), tags) self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags) self.iter_get_s.set(stats.iter_get_s.get(), tags) self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags) @@ -748,6 +808,16 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) + self.iter_blocked_fetch_s.set(stats.iter_blocked_fetch_s.get(), tags) + self.iter_blocked_batching_s.set(stats.iter_blocked_batching_s.get(), tags) + self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) + self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) + self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) + self.iter_blocked_restore_order_s.set( + stats.iter_blocked_restore_order_s.get(), tags + ) + self.iter_batches_total.set(stats.iter_batches_total, tags) + self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) def register_dataset( @@ -941,6 +1011,9 @@ def _create_tags( tags["node_ip"] = node_ip_tag return tags + def _create_iteration_tags(self, dataset_tag: Optional[str]): + return _create_iteration_tags(dataset_tag) + def get_or_create_stats_actor() -> ActorHandle[_StatsActor]: """Each cluster will contain exactly 1 _StatsActor. This function @@ -1138,9 +1211,17 @@ def __init__( self.iter_finalize_batch_s: Timer = Timer() self.iter_time_to_first_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() + self.iter_blocked_fetch_s: Timer = Timer() + self.iter_blocked_batching_s: Timer = Timer() + self.iter_blocked_format_s: Timer = Timer() + self.iter_blocked_collate_s: Timer = Timer() + self.iter_blocked_finalize_s: Timer = Timer() + self.iter_blocked_restore_order_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() + self.iter_batches_total: int = 0 + self.iter_rows_total: int = 0 self.extra_metrics = {} # Block fetch stats during iteration. @@ -1196,6 +1277,14 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocks_remote, self.iter_unknown_location, self.iter_prefetched_bytes, + self.iter_blocked_fetch_s, + self.iter_blocked_batching_s, + self.iter_blocked_format_s, + self.iter_blocked_collate_s, + self.iter_blocked_finalize_s, + self.iter_blocked_restore_order_s, + self.iter_batches_total, + self.iter_rows_total, ) stats_summary_parents = [] @@ -1878,6 +1967,16 @@ class IterStatsSummary: iter_unknown_location: int # Current bytes of prefetched blocks in the iterator iter_prefetched_bytes: int + # Per-stage training-thread blocked attribution timers. + blocked_fetch_time: Timer + blocked_batching_time: Timer + blocked_format_time: Timer + blocked_collate_time: Timer + blocked_finalize_time: Timer + blocked_restore_order_time: Timer + # Cumulative batch and row counters. + batches_total: int + rows_total: int def __str__(self) -> str: return self.to_string() @@ -1984,6 +2083,25 @@ def to_string(self) -> str: out += "Streaming split coordinator overhead time: " out += f"{fmt(self.streaming_split_coord_time.get())}\n" + # Per-stage training-thread blocked attribution. + stage_totals = [ + ("block fetch (ray.get)", self.blocked_fetch_time), + ("batching", self.blocked_batching_time), + ("format", self.blocked_format_time), + ("collate", self.blocked_collate_time), + ("finalize (host->device)", self.blocked_finalize_time), + ("restore order", self.blocked_restore_order_time), + ] + active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] + if active_stages: + out += "\nPer-stage training-thread blocked time breakdown:\n" + for stage_name, timer in active_stages: + out += " * {}: {}\n".format(stage_name, fmt(timer.get())) + if self.batches_total: + out += "Total batches consumed: {}\n".format(self.batches_total) + if self.rows_total: + out += "Total rows consumed: {}\n".format(self.rows_total) + return out def __repr__(self, level=0) -> str: diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 1f69962435ad..9cfaea2f8dfc 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -22,7 +22,7 @@ from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.logical.interfaces import LogicalPlan from ray.data._internal.logical.operators import InputData -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, _StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.collate_fn import ( ArrowBatchCollateFn, @@ -297,6 +297,7 @@ def callback(num_bytes: int) -> None: yield from batch_iterator if stats: stats.iter_total_s.add(time.perf_counter() - time_start) + _StatsManager.update_iteration_metrics(stats, dataset_tag) finally: # On early exit (e.g. ``break`` in the for-loop), the inner # ``_ClosingIterator`` would only shut down the executor via diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b0b55d1cb8fb..4c6fef31a4f5 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -11,7 +11,9 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + StageTiming, ) from ray.data._internal.block_batching.iter_batches import ( BatchIterator, @@ -114,6 +116,415 @@ def test_restore_from_original_order(): assert idx == [0, 1, 2, 3] +def test_restore_original_order_stats(): + stats = DatasetStats(metadata={}, parent=None) + base_iterator = [ + Batch(BatchMetadata(batch_idx=2), None), + Batch(BatchMetadata(batch_idx=0), None), + Batch(BatchMetadata(batch_idx=1), None), + ] + + ordered = list(restore_original_order(iter(base_iterator), stats=stats)) + + assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] + assert any( + batch.metadata.timings.restore_order.start_s > 0 + and batch.metadata.timings.restore_order.end_s + >= batch.metadata.timings.restore_order.start_s + for batch in ordered + ) + + +def test_report_batch_timings_overlap_attribution(): + stats = DatasetStats(metadata={}, parent=None) + batch_iterator = BatchIterator(iter([]), stats=stats) + timings = BatchTimings(num_rows=8) + timings.fetch = StageTiming(start_s=10.0, end_s=20.0) + timings.batching = StageTiming(start_s=20.0, end_s=30.0) + timings.format = StageTiming(start_s=30.0, end_s=40.0) + timings.finalize = StageTiming(start_s=50.0, end_s=60.0) + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + batch_iterator._report_batch_timings( + batch, blocked_start_s=15.0, blocked_end_s=35.0 + ) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(10.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_collate_s.get() == 0 + assert stats.iter_blocked_finalize_s.get() == 0 + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 8 + + +def _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.0, + batching_start=0.0, + batching_end=0.0, + format_start=0.0, + format_end=0.0, + collate_start=0.0, + collate_end=0.0, + finalize_start=0.0, + finalize_end=0.0, + num_rows=0, +): + """Helper to construct a Batch with specific stage timing windows.""" + timings = BatchTimings(num_rows=num_rows) + timings.fetch = StageTiming(start_s=fetch_start, end_s=fetch_end) + timings.batching = StageTiming(start_s=batching_start, end_s=batching_end) + timings.format = StageTiming(start_s=format_start, end_s=format_end) + timings.collate = StageTiming(start_s=collate_start, end_s=collate_end) + timings.finalize = StageTiming(start_s=finalize_start, end_s=finalize_end) + return Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + +def _make_report_iterator(stats): + """Create a BatchIterator wired to the given stats without a real pipeline.""" + it = BatchIterator.__new__(BatchIterator) + it._stats = stats + return it + + +class TestReportBatchTimingsEdgeCases: + """Edge case tests for overlap-based blocked attribution.""" + + def test_zero_overlap_stage_finished_before_blocked(self): + """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.5) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == 0.0 + + def test_zero_overlap_blocked_before_stage(self): + """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=1.0) + assert stats.iter_blocked_format_s.get() == 0.0 + + def test_partial_overlap(self): + """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=2.0) + it._report_batch_timings(batch, blocked_start_s=1.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + + def test_full_overlap_stage_inside_blocked(self): + """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + + def test_no_collate_fn_zero_attribution(self): + """collate stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == 0.0 + + def test_no_finalize_fn_zero_attribution(self): + """finalize stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == 0.0 + + def test_prefetch_hides_fetch_from_training(self): + """Effective prefetch: fetch done before training blocks → 0 fetch attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.5, + collate_start=2.3, + collate_end=2.6, + ) + # Training only starts blocking at t=2 (prefetch worked) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=2.6) + assert stats.iter_blocked_fetch_s.get() == 0.0 + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) + + def test_accumulation_across_batches(self): + """Two batches each contribute to fetch — values accumulate.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 + b1 = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.0, num_rows=10) + it._report_batch_timings(b1, blocked_start_s=0.0, blocked_end_s=2.0) + # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 + b2 = _make_batch_with_timings(fetch_start=5.0, fetch_end=6.0, num_rows=20) + it._report_batch_timings(b2, blocked_start_s=5.0, blocked_end_s=7.0) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 30 + + def test_overlap_invariant_sum_leq_total(self): + """sum(iter_blocked_*) <= iter_total_blocked_s always holds.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + num_rows=5, + ) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + total = stats.iter_total_blocked_s.get() + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + + stats.iter_blocked_restore_order_s.get() + ) + assert sum_stages <= total + 1e-9 + + def test_restore_order_overlap(self): + """restore_order stage timing is correctly attributed.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + ) + batch.metadata.timings.restore_order = StageTiming(start_s=1.5, end_s=2.5) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_restore_order_s.get() == pytest.approx(1.0) + + def test_blocked_inside_stage(self): + """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=10.0) + it._report_batch_timings(batch, blocked_start_s=3.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + + def test_all_stages_simultaneous_overlap(self): + """Multiple stages overlap with blocked window simultaneously.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + collate_start=3.0, + collate_end=4.0, + finalize_start=4.0, + finalize_end=5.0, + num_rows=100, + ) + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(1.0) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 100 + + +class TestStageTimingRecord: + """Tests for StageTiming.record() behavior.""" + + def test_record_sets_start_and_end(self): + """First record() sets both start_s and end_s.""" + t = StageTiming() + t.record(1.0, 2.0) + assert t.start_s == 1.0 + assert t.end_s == 2.0 + + def test_record_keeps_first_start(self): + """Subsequent record() calls keep the first start_s.""" + t = StageTiming() + t.record(1.0, 2.0) + t.record(3.0, 4.0) + assert t.start_s == 1.0 # kept first start + assert t.end_s == 4.0 # updated to latest end + + def test_record_multiple_expands_window(self): + """Multiple record() calls expand the end_s window.""" + t = StageTiming() + t.record(5.0, 6.0) + t.record(7.0, 8.0) + t.record(9.0, 10.0) + assert t.start_s == 5.0 + assert t.end_s == 10.0 + + def test_record_default_values(self): + """Unrecorded StageTiming has start_s=0 and end_s=0.""" + t = StageTiming() + assert t.start_s == 0.0 + assert t.end_s == 0.0 + + +class TestMergeFetch: + """Tests for BatchTimings.merge_fetch() with multiple blocks per batch.""" + + def test_merge_single_block(self): + """Merging a single block preserves its fetch window.""" + dst = BatchTimings() + src = BatchTimings() + src.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src) + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 2.0 + + def test_merge_multiple_blocks_expands_window(self): + """Merging multiple blocks produces the union window.""" + dst = BatchTimings() + + # Block 1: fetched [1.0, 2.0] + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src1) + + # Block 2: fetched [3.0, 4.0] + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=4.0) + dst.merge_fetch(src2) + + # Block 3: fetched [5.0, 6.0] + src3 = BatchTimings() + src3.fetch = StageTiming(start_s=5.0, end_s=6.0) + dst.merge_fetch(src3) + + # Union: [1.0, 6.0] + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 6.0 + + def test_merge_unrecorded_block_ignored(self): + """Merging a block with no fetch timing (start_s=0) is a no-op.""" + dst = BatchTimings() + dst.fetch = StageTiming(start_s=2.0, end_s=3.0) + + src = BatchTimings() # fetch defaults to (0.0, 0.0) + dst.merge_fetch(src) + + assert dst.fetch.start_s == 2.0 + assert dst.fetch.end_s == 3.0 + + def test_merge_overlapping_blocks(self): + """Overlapping fetch windows are correctly merged.""" + dst = BatchTimings() + + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + dst.merge_fetch(src1) + + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + dst.merge_fetch(src2) + + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 7.0 + + def test_merge_into_empty_destination(self): + """Merging into an empty BatchTimings takes the source window.""" + dst = BatchTimings() # fetch = (0.0, 0.0) + src = BatchTimings() + src.fetch = StageTiming(start_s=10.0, end_s=20.0) + dst.merge_fetch(src) + assert dst.fetch.start_s == 10.0 + assert dst.fetch.end_s == 20.0 + + +class TestEndToEndTimingPropagation: + """Tests that stage timings propagate correctly through the full pipeline.""" + + def test_batch_carries_timings_through_pipeline(self): + """A Batch's metadata.timings carries all stage windows.""" + timings = BatchTimings(num_rows=50) + timings.fetch = StageTiming(start_s=1.0, end_s=2.0) + timings.batching = StageTiming(start_s=2.0, end_s=3.0) + timings.format = StageTiming(start_s=3.0, end_s=4.0) + timings.collate = StageTiming(start_s=4.0, end_s=5.0) + timings.finalize = StageTiming(start_s=5.0, end_s=6.0) + + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + # Verify all stages are accessible via stages() iterator + stage_dict = dict(batch.metadata.timings.stages()) + assert len(stage_dict) == 6 + assert stage_dict["fetch"].start_s == 1.0 + assert stage_dict["batching"].end_s == 3.0 + assert stage_dict["format"].start_s == 3.0 + assert stage_dict["collate"].end_s == 5.0 + assert stage_dict["finalize"].start_s == 5.0 + assert stage_dict["restore_order"].start_s == 0.0 # not recorded + assert batch.metadata.timings.num_rows == 50 + + def test_full_pipeline_attribution(self): + """End-to-end: all 6 stages with realistic timing, full overlap.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.5, + batching_start=0.5, + batching_end=1.0, + format_start=1.0, + format_end=2.0, + collate_start=2.0, + collate_end=2.5, + finalize_start=2.5, + finalize_end=3.0, + num_rows=256, + ) + # Also set restore_order + batch.metadata.timings.restore_order = StageTiming(start_s=3.0, end_s=3.5) + + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + # Each stage gets its full duration + assert stats.iter_blocked_fetch_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_batching_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_restore_order_s.get() == pytest.approx(0.5) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 256 + + # Invariant: sum = 3.5 <= total_blocked = 5.0 + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + + stats.iter_blocked_restore_order_s.get() + ) + assert sum_stages == pytest.approx(3.5) + assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 + + def test_finalize_fn_uses_single_thread(ray_start_regular_shared): """Tests that finalize_fn is not run with multiple threads.""" ref_bundles_iter = ref_bundle_generator(num_blocks=20, num_rows=2) @@ -194,6 +605,27 @@ def collate_fn(batch: pd.DataFrame): assert concat_df["foo"].iloc[i + 1] >= concat_df["foo"].iloc[i] +def test_iter_batches_counts_rows_at_pipeline_exit(ray_start_regular_shared): + stats = DatasetStats(metadata={}, parent=None) + ref_bundles_iter = ref_bundle_generator(num_blocks=4, num_rows=2) + + output_batches = list( + BatchIterator( + ref_bundles_iter, + stats=stats, + batch_size=3, + prefetch_batches=0, + batch_format="pandas", + drop_last=True, + ) + ) + + assert len(output_batches) == 2 + assert [len(batch) for batch in output_batches] == [3, 3] + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 6 + + def test_iter_batches_e2e_async(ray_start_regular_shared): """We add time.sleep in 3 places: 1. In the base generator to simulate streaming executor blocking on next results. diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index cb4c31553541..28af6b1ae1ea 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -36,6 +36,7 @@ OperatorStatsSummary, StatsSummary, Timer, + _create_iteration_tags, _StatsActor, get_or_create_stats_actor, ) @@ -1878,6 +1879,164 @@ def test_stats_actor_iter_metrics(): assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1] +def test_create_iteration_tags_extracts_rank(): + assert _create_iteration_tags("train_abc_split_2") == { + "dataset": "train_abc_split_2", + "rank": "2", + } + assert _create_iteration_tags("dataset_without_split") == { + "dataset": "dataset_without_split", + "rank": "unknown", + } + # User-defined dataset name may contain split_; the trailing + # split index (from streaming split coordinator) should be used. + assert _create_iteration_tags("my_split_3_data_abc123_split_5") == { + "dataset": "my_split_3_data_abc123_split_5", + "rank": "5", + } + + +def test_update_iteration_metrics_exports_new_iter_metrics(): + stats = DatasetStats(metadata={}, parent=None) + stats.iter_total_s.add(11.0) + stats.iter_blocked_fetch_s.add(1.0) + stats.iter_blocked_batching_s.add(2.0) + stats.iter_blocked_format_s.add(3.0) + stats.iter_blocked_collate_s.add(4.0) + stats.iter_blocked_finalize_s.add(5.0) + stats.iter_blocked_restore_order_s.add(6.0) + stats.iter_batches_total = 7 + stats.iter_rows_total = 8 + + actor = _StatsActor.__ray_metadata__.modified_class.__new__( + _StatsActor.__ray_metadata__.modified_class + ) + recorded = {} + + class FakeGauge: + def __init__(self, name): + self.name = name + + def set(self, value, tags): + recorded[self.name] = (value, tags) + + for attr in [ + "iter_initialize_s", + "iter_total_s", + "iter_get_ref_bundles_s", + "iter_get_s", + "iter_next_batch_s", + "iter_format_batch_s", + "iter_collate_batch_s", + "iter_finalize_batch_s", + "iter_blocks_local", + "iter_blocks_remote", + "iter_unknown_location", + "iter_prefetched_bytes", + "iter_block_fetching_s", + "iter_batch_shaping_s", + "iter_batch_formatting_s", + "iter_batch_collating_s", + "iter_batch_finalizing_s", + "time_to_first_batch_s", + "iter_total_blocked_s", + "iter_blocked_fetch_s", + "iter_blocked_batching_s", + "iter_blocked_format_s", + "iter_blocked_collate_s", + "iter_blocked_finalize_s", + "iter_blocked_restore_order_s", + "iter_batches_total", + "iter_rows_total", + "iter_user_s", + ]: + setattr(actor, attr, FakeGauge(attr)) + + actor.update_iteration_metrics(stats, "train_dataset_split_3") + + expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} + assert recorded["iter_total_s"] == (11.0, expected_tags) + assert recorded["iter_blocked_fetch_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_batching_s"] == (2.0, expected_tags) + assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) + assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) + assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) + assert recorded["iter_blocked_restore_order_s"] == (6.0, expected_tags) + assert recorded["iter_batches_total"] == (7, expected_tags) + assert recorded["iter_rows_total"] == (8, expected_tags) + + +def test_iter_stats_summary_has_new_fields(): + """IterStatsSummary includes per-stage blocked timers and counters.""" + stats = DatasetStats(metadata={}, parent=None) + summary = stats.to_summary() + iter_summary = summary.iter_stats + + assert hasattr(iter_summary, "blocked_fetch_time") + assert hasattr(iter_summary, "blocked_batching_time") + assert hasattr(iter_summary, "blocked_format_time") + assert hasattr(iter_summary, "blocked_collate_time") + assert hasattr(iter_summary, "blocked_finalize_time") + assert hasattr(iter_summary, "blocked_restore_order_time") + assert hasattr(iter_summary, "batches_total") + assert hasattr(iter_summary, "rows_total") + + +def test_iter_stats_summary_reflects_accumulated_values(): + """IterStatsSummary carries the accumulated timer values.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_blocked_batching_s.add(0.2) + stats.iter_batches_total = 10 + stats.iter_rows_total = 320 + + summary = stats.to_summary().iter_stats + assert summary.blocked_fetch_time.get() == pytest.approx(0.5) + assert summary.blocked_batching_time.get() == pytest.approx(0.2) + assert summary.batches_total == 10 + assert summary.rows_total == 320 + + +def test_iter_stats_to_string_shows_stage_breakdown(): + """to_string() renders per-stage breakdown when values are non-zero.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(1.5) + stats.iter_blocked_format_s.add(0.8) + stats.iter_batches_total = 5 + stats.iter_rows_total = 160 + stats.iter_total_blocked_s.add(2.3) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + assert "format" in text + assert "Total batches consumed: 5" in text + assert "Total rows consumed: 160" in text + assert "Per-stage training-thread blocked time breakdown" in text + + +def test_iter_stats_to_string_omits_zero_stages(): + """to_string() omits stages with zero values from the breakdown.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_total_blocked_s.add(0.5) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + # Zero stages should not appear + assert "batching" not in text + assert "collate" not in text + assert "restore order" not in text + + +def test_iter_stats_to_string_no_breakdown_when_all_zero(): + """When all blocked_* stages are zero, no breakdown section appears.""" + stats = DatasetStats(metadata={}, parent=None) + text = str(stats.to_summary().iter_stats) + assert "Per-stage training-thread blocked time breakdown" not in text + assert "Total batches consumed" not in text + assert "Total rows consumed" not in text + + def test_dataset_name_and_id(): # Test deprecated APIs: _set_name and _name ds = ray.data.range(1) From f3935abf291cfa6cd75026774b05bfd013c36ddf Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 20 Jun 2026 12:49:56 +0800 Subject: [PATCH 02/12] [Data] Remove unused ShufflingBatcher compaction tracking Reverts batcher.py changes that were only needed for the shuffle buffer metrics which have been removed from this PR's scope per reviewer feedback. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/batcher.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index 7ff6ccf42c11..c097ee668de6 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -1,4 +1,3 @@ -import time import warnings from typing import Optional @@ -236,8 +235,6 @@ def __init__( self._total_object_store_nbytes = get_total_obj_store_mem_on_node() self._total_num_rows_added = 0 self._total_nbytes_added = 0 - self.compactions_total = 0 - self.compaction_time_s = 0.0 def add(self, block: Block): """Add a block to the shuffle buffer. @@ -323,9 +320,6 @@ def _num_rows(self) -> int: """ return self._num_compacted_rows() + self._num_uncompacted_rows() - def num_rows(self) -> int: - return self._num_rows() - def _num_compacted_rows(self) -> int: """Return number of unyielded rows in the compacted buffer.""" if self._shuffle_buffer is None: @@ -347,7 +341,6 @@ def next_batch(self) -> Block: self._done_adding or self._num_compacted_rows() <= self._min_rows_to_yield_batch ): - compaction_start_s = time.perf_counter() if self._shuffle_buffer is not None and self._batch_head < len( self._shuffled_indices ): @@ -370,8 +363,6 @@ def next_batch(self) -> Block: self._builder = DelegatingBlockBuilder() self._batch_head = 0 - self.compactions_total += 1 - self.compaction_time_s += time.perf_counter() - compaction_start_s assert self._shuffle_buffer is not None assert self._shuffled_indices is not None From 543a68d7e198d39721d295e0ba785472927305db Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:30:56 +0800 Subject: [PATCH 03/12] [Data] Remove restore_order stage from blocked attribution Per reviewer feedback, restore_order is an implementation detail rather than an actionable user-facing metric. Reverts restore_original_order() to the original simple for-loop and removes the data_iter_blocked_restore_order_seconds Prometheus metric along with all related fields, exports, and tests. The PR now exposes 8 core metrics (5 blocked stages + batches/rows/total). Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 2 - .../_internal/block_batching/iter_batches.py | 38 +++++-------------- python/ray/data/_internal/stats.py | 12 ------ .../tests/block_batching/test_iter_batches.py | 36 +++--------------- python/ray/data/tests/test_stats.py | 4 -- 5 files changed, 14 insertions(+), 78 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 10587586c941..aabb6781a33d 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -26,7 +26,6 @@ class BatchTimings: format: StageTiming = field(default_factory=StageTiming) collate: StageTiming = field(default_factory=StageTiming) finalize: StageTiming = field(default_factory=StageTiming) - restore_order: StageTiming = field(default_factory=StageTiming) num_rows: int = 0 def stages(self) -> Iterable[Tuple[str, StageTiming]]: @@ -36,7 +35,6 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ("format", self.format), ("collate", self.collate), ("finalize", self.finalize), - ("restore_order", self.restore_order), ) def merge_fetch(self, other: "BatchTimings") -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 66e55437d68f..2572f6a10dcc 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -218,7 +218,7 @@ def _finalize_batches( def _restore_original_batch_order( self, batches: Iterator[Batch] ) -> Iterator[Batch]: - return restore_original_order(batches, stats=self._stats) + return restore_original_order(batches) def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: # Step 1: Prefetch logical batches locally. @@ -474,9 +474,7 @@ def get_next_ref_bundle() -> RefBundle: prefetcher.stop() -def restore_original_order( - batch_iter: Iterator[Batch], stats: Optional[DatasetStats] = None -) -> Iterator[Batch]: +def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on @@ -487,31 +485,13 @@ def restore_original_order( """ next_index_required = 0 buffer: Dict[int, Batch] = {} - restore_wait_start_s: Optional[float] = None - source_exhausted = False - - while True: + for batch in batch_iter: + assert batch.metadata.batch_idx not in buffer + buffer[batch.metadata.batch_idx] = batch while next_index_required in buffer: - next_batch = buffer.pop(next_index_required) - if restore_wait_start_s is not None: - next_batch.metadata.timings.restore_order.record( - restore_wait_start_s, time.perf_counter() - ) - restore_wait_start_s = None - yield next_batch + yield buffer.pop(next_index_required) next_index_required += 1 - if source_exhausted: - break - - if buffer and restore_wait_start_s is None: - restore_wait_start_s = time.perf_counter() - - try: - batch = next(batch_iter) - except StopIteration: - source_exhausted = True - continue - - assert batch.metadata.batch_idx not in buffer - buffer[batch.metadata.batch_idx] = batch + while next_index_required in buffer: + yield buffer.pop(next_index_required) + next_index_required += 1 diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 7afdb967aeaa..adb06af01c2f 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -532,11 +532,6 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked on batch finalization", tag_keys=iter_tag_keys, ) - self.iter_blocked_restore_order_s = Gauge( - "data_iter_blocked_restore_order_seconds", - description="Seconds user thread is blocked on restoring batch order", - tag_keys=iter_tag_keys, - ) self.iter_batches_total = Gauge( "data_iter_batches_total", description="Total batches delivered to the user thread", @@ -813,9 +808,6 @@ def update_iteration_metrics( self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) - self.iter_blocked_restore_order_s.set( - stats.iter_blocked_restore_order_s.get(), tags - ) self.iter_batches_total.set(stats.iter_batches_total, tags) self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) @@ -1216,7 +1208,6 @@ def __init__( self.iter_blocked_format_s: Timer = Timer() self.iter_blocked_collate_s: Timer = Timer() self.iter_blocked_finalize_s: Timer = Timer() - self.iter_blocked_restore_order_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() @@ -1282,7 +1273,6 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocked_format_s, self.iter_blocked_collate_s, self.iter_blocked_finalize_s, - self.iter_blocked_restore_order_s, self.iter_batches_total, self.iter_rows_total, ) @@ -1973,7 +1963,6 @@ class IterStatsSummary: blocked_format_time: Timer blocked_collate_time: Timer blocked_finalize_time: Timer - blocked_restore_order_time: Timer # Cumulative batch and row counters. batches_total: int rows_total: int @@ -2090,7 +2079,6 @@ def to_string(self) -> str: ("format", self.blocked_format_time), ("collate", self.blocked_collate_time), ("finalize (host->device)", self.blocked_finalize_time), - ("restore order", self.blocked_restore_order_time), ] active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] if active_stages: diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 4c6fef31a4f5..3d6e94fb0676 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -117,22 +117,15 @@ def test_restore_from_original_order(): def test_restore_original_order_stats(): - stats = DatasetStats(metadata={}, parent=None) base_iterator = [ Batch(BatchMetadata(batch_idx=2), None), Batch(BatchMetadata(batch_idx=0), None), Batch(BatchMetadata(batch_idx=1), None), ] - ordered = list(restore_original_order(iter(base_iterator), stats=stats)) + ordered = list(restore_original_order(iter(base_iterator))) assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] - assert any( - batch.metadata.timings.restore_order.start_s > 0 - and batch.metadata.timings.restore_order.end_s - >= batch.metadata.timings.restore_order.start_s - for batch in ordered - ) def test_report_batch_timings_overlap_attribution(): @@ -294,23 +287,9 @@ def test_overlap_invariant_sum_leq_total(self): + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() + stats.iter_blocked_finalize_s.get() - + stats.iter_blocked_restore_order_s.get() ) assert sum_stages <= total + 1e-9 - def test_restore_order_overlap(self): - """restore_order stage timing is correctly attributed.""" - stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) - batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=1.0, - ) - batch.metadata.timings.restore_order = StageTiming(start_s=1.5, end_s=2.5) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) - assert stats.iter_blocked_restore_order_s.get() == pytest.approx(1.0) - def test_blocked_inside_stage(self): """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" stats = DatasetStats(metadata={}, parent=None) @@ -468,17 +447,16 @@ def test_batch_carries_timings_through_pipeline(self): # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) - assert len(stage_dict) == 6 + assert len(stage_dict) == 5 assert stage_dict["fetch"].start_s == 1.0 assert stage_dict["batching"].end_s == 3.0 assert stage_dict["format"].start_s == 3.0 assert stage_dict["collate"].end_s == 5.0 assert stage_dict["finalize"].start_s == 5.0 - assert stage_dict["restore_order"].start_s == 0.0 # not recorded assert batch.metadata.timings.num_rows == 50 def test_full_pipeline_attribution(self): - """End-to-end: all 6 stages with realistic timing, full overlap.""" + """End-to-end: all 5 stages with realistic timing, full overlap.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) stats.iter_total_blocked_s.add(5.0) @@ -496,8 +474,6 @@ def test_full_pipeline_attribution(self): finalize_end=3.0, num_rows=256, ) - # Also set restore_order - batch.metadata.timings.restore_order = StageTiming(start_s=3.0, end_s=3.5) # Blocked window covers all stages it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) @@ -508,20 +484,18 @@ def test_full_pipeline_attribution(self): assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) - assert stats.iter_blocked_restore_order_s.get() == pytest.approx(0.5) assert stats.iter_batches_total == 1 assert stats.iter_rows_total == 256 - # Invariant: sum = 3.5 <= total_blocked = 5.0 + # Invariant: sum = 3.0 <= total_blocked = 5.0 sum_stages = ( stats.iter_blocked_fetch_s.get() + stats.iter_blocked_batching_s.get() + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() + stats.iter_blocked_finalize_s.get() - + stats.iter_blocked_restore_order_s.get() ) - assert sum_stages == pytest.approx(3.5) + assert sum_stages == pytest.approx(3.0) assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 28af6b1ae1ea..245c0bfed8d4 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1904,7 +1904,6 @@ def test_update_iteration_metrics_exports_new_iter_metrics(): stats.iter_blocked_format_s.add(3.0) stats.iter_blocked_collate_s.add(4.0) stats.iter_blocked_finalize_s.add(5.0) - stats.iter_blocked_restore_order_s.add(6.0) stats.iter_batches_total = 7 stats.iter_rows_total = 8 @@ -1945,7 +1944,6 @@ def set(self, value, tags): "iter_blocked_format_s", "iter_blocked_collate_s", "iter_blocked_finalize_s", - "iter_blocked_restore_order_s", "iter_batches_total", "iter_rows_total", "iter_user_s", @@ -1961,7 +1959,6 @@ def set(self, value, tags): assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) - assert recorded["iter_blocked_restore_order_s"] == (6.0, expected_tags) assert recorded["iter_batches_total"] == (7, expected_tags) assert recorded["iter_rows_total"] == (8, expected_tags) @@ -1977,7 +1974,6 @@ def test_iter_stats_summary_has_new_fields(): assert hasattr(iter_summary, "blocked_format_time") assert hasattr(iter_summary, "blocked_collate_time") assert hasattr(iter_summary, "blocked_finalize_time") - assert hasattr(iter_summary, "blocked_restore_order_time") assert hasattr(iter_summary, "batches_total") assert hasattr(iter_summary, "rows_total") From 3a39cbd66c2136ecf098772f9a68d6685dfbcae2 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:50:33 +0800 Subject: [PATCH 04/12] [Data] Consolidate timing into StageTiming context manager Per reviewer feedback, consolidates the dual timing mechanism: - StageTiming now supports context manager protocol (__enter__/ __exit__) to automatically capture start_s/end_s - Timer gains start_s/end_s fields populated by timer() - Pipeline functions (resolve_block_refs, _format_batch, _collate_batch, _finalize_batch) use nested context managers instead of redundant perf_counter() + _record_stage_window() - resolve_block_refs always returns BlockWithTiming, removing the record_timings parameter, Union types, and isinstance branching - Removed _record_stage_window helper (no longer needed) Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 31 +++++-- .../_internal/block_batching/iter_batches.py | 4 +- .../ray/data/_internal/block_batching/util.py | 86 ++++++++----------- python/ray/data/_internal/stats.py | 9 +- .../tests/block_batching/test_iter_batches.py | 34 +++----- 5 files changed, 81 insertions(+), 83 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index aabb6781a33d..4f3020e65412 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,4 +1,6 @@ import abc +import time +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, Iterable, List, Tuple @@ -8,15 +10,34 @@ @dataclass class StageTiming: - """Wall-clock window for a batch-processing stage.""" + """Wall-clock window for a single batch-processing stage. + + Can be used as a context manager to automatically capture the start and + end timestamps of a pipeline operation:: + + with stage_timing: + do_work() + # stage_timing.start_s and stage_timing.end_s are now set + """ start_s: float = 0.0 end_s: float = 0.0 - def record(self, start_s: float, end_s: float) -> None: - if self.start_s == 0.0: - self.start_s = start_s - self.end_s = end_s + def __enter__(self): + self.start_s = time.perf_counter() + return self + + def __exit__(self, *args): + self.end_s = time.perf_counter() + + @contextmanager + def timer(self): + """Alias for using as a context manager, matching Timer.timer() API.""" + self.start_s = time.perf_counter() + try: + yield + finally: + self.end_s = time.perf_counter() @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 2572f6a10dcc..dbcfedecd748 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -176,9 +176,7 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] ) -> Iterator[Any]: - return resolve_block_refs( - block_ref_iter=block_refs, stats=self._stats, record_timings=True - ) + return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: return blocks_to_batches( diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index fb9f1f3aa11a..a7dfffe9bee9 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,7 +3,6 @@ import logging import queue import threading -import time from contextlib import nullcontext from typing import ( Any, @@ -15,7 +14,6 @@ Optional, Tuple, TypeVar, - Union, ) import ray @@ -28,7 +26,6 @@ BlockPrefetcher, BlockWithTiming, CollatedBatch, - StageTiming, ) from ray.data._internal.stats import DatasetStats from ray.data.block import Block, BlockAccessor, DataBatch @@ -175,27 +172,24 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: return 0, 0, 0 -def _record_stage_window(stage: StageTiming, start_s: float, end_s: float) -> None: - stage.record(start_s, end_s) - - def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, - record_timings: bool = False, -) -> Iterator[Union[Block, BlockWithTiming]]: +) -> Iterator[BlockWithTiming]: """Resolves the block references for each logical batch. + Each resolved block is wrapped in a :class:`BlockWithTiming` that carries + the per-block fetch window (``start_s``/``end_s`` around ``ray.get()``). + When *stats* is provided, the cumulative fetch time is also recorded in + ``stats.iter_get_s``. + Args: block_ref_iter: An iterator over block object references. - stats: An optional stats object to recording block hits and misses. - record_timings: If True, wrap each resolved block in a - ``BlockWithTiming`` carrying the per-block fetch window. + stats: An optional stats object to record block hits, misses, and + cumulative fetch time. Yields: - Union[Block, BlockWithTiming]: The resolved blocks. When - *record_timings* is ``True`` each block is wrapped in a - ``BlockWithTiming``; otherwise raw ``Block`` instances are yielded. + BlockWithTiming: Each resolved block with its fetch timing window. """ hits = 0 misses = 0 @@ -209,16 +203,11 @@ def resolve_block_refs( # TODO(amogkam): Optimized further by batching multiple references in a single # `ray.get()` call. - start_s = time.perf_counter() - with stats.iter_get_s.timer() if stats else nullcontext(): - block = ray.get(block_ref) - end_s = time.perf_counter() - if record_timings: - timings = BatchTimings() - _record_stage_window(timings.fetch, start_s, end_s) - yield BlockWithTiming(block=block, timings=timings) - else: - yield block + timings = BatchTimings() + with timings.fetch: + with stats.iter_get_s.timer() if stats else nullcontext(): + block = ray.get(block_ref) + yield BlockWithTiming(block=block, timings=timings) if stats: stats.iter_blocks_local = hits @@ -227,7 +216,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -256,7 +245,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -294,9 +283,8 @@ def __next__(self) -> Batch: if can_yield: with timer: - start_s = time.perf_counter() - next_batch = self._batcher.next_batch() - end_s = time.perf_counter() + with self._pending_timings.batching: + next_batch = self._batcher.next_batch() res = Batch( metadata=BatchMetadata( @@ -305,7 +293,6 @@ def __next__(self) -> Batch: ), data=next_batch, ) - _record_stage_window(res.metadata.timings.batching, start_s, end_s) res.metadata.timings.num_rows = BlockAccessor.for_block( next_batch ).num_rows() @@ -318,11 +305,9 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - if isinstance(block, BlockWithTiming): - self._pending_timings.merge_fetch(block.timings) - block = block.block - self._batcher.add(block) + block_with_timing = next(self._block_iter) + self._pending_timings.merge_fetch(block_with_timing.timings) + self._batcher.add(block_with_timing.block) except StopIteration: self._batcher.done_adding() self._done_adding = True @@ -341,14 +326,13 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - start_s = time.perf_counter() - with stats.iter_format_batch_s.timer() if stats else nullcontext(): - formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) - if ensure_copy: - formatted_data = _copy_batch(formatted_data) - _record_stage_window(batch.metadata.timings.format, start_s, time.perf_counter()) + with batch.metadata.timings.format: + with stats.iter_format_batch_s.timer() if stats else nullcontext(): + formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) + if ensure_copy: + formatted_data = _copy_batch(formatted_data) return dataclasses.replace(batch, data=formatted_data) @@ -396,10 +380,9 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - start_s = time.perf_counter() - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): - collated_data = collate_fn(batch.data) - _record_stage_window(batch.metadata.timings.collate, start_s, time.perf_counter()) + with batch.metadata.timings.collate: + with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + collated_data = collate_fn(batch.data) return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -423,10 +406,9 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - start_s = time.perf_counter() - with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): - finalized_data = finalize_fn(batch.data) - _record_stage_window(batch.metadata.timings.finalize, start_s, time.perf_counter()) + with batch.metadata.timings.finalize: + with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): + finalized_data = finalize_fn(batch.data) return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index adb06af01c2f..dc5c32d1d500 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -195,17 +195,22 @@ def __init__(self): self._min: float = float("inf") self._max: float = 0 self._total_count: float = 0 + # Wall-clock window of the most recent timer() invocation. + # Used by overlap-based blocked attribution in iter_batches. + self.start_s: float = 0.0 + self.end_s: float = 0.0 # Bounded-memory percentile backend. add() forwards every value # to ``add_sample`` and ``percentile`` reads from it. self._distribution: DistributionTracker = DistributionTracker() @contextmanager def timer(self) -> None: - time_start = time.perf_counter() + self.start_s = time.perf_counter() try: yield finally: - self.add(time.perf_counter() - time_start) + self.end_s = time.perf_counter() + self.add(self.end_s - self.start_s) def add(self, value: float) -> None: self._total += value diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 3d6e94fb0676..3ca2c68725d8 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -329,31 +329,23 @@ def test_all_stages_simultaneous_overlap(self): class TestStageTimingRecord: """Tests for StageTiming.record() behavior.""" - def test_record_sets_start_and_end(self): - """First record() sets both start_s and end_s.""" + def test_context_manager_captures_window(self): + """Using as context manager captures start_s and end_s.""" t = StageTiming() - t.record(1.0, 2.0) - assert t.start_s == 1.0 - assert t.end_s == 2.0 + with t: + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s - def test_record_keeps_first_start(self): - """Subsequent record() calls keep the first start_s.""" + def test_timer_context_manager(self): + """The timer() method works as a context manager too.""" t = StageTiming() - t.record(1.0, 2.0) - t.record(3.0, 4.0) - assert t.start_s == 1.0 # kept first start - assert t.end_s == 4.0 # updated to latest end + with t.timer(): + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s - def test_record_multiple_expands_window(self): - """Multiple record() calls expand the end_s window.""" - t = StageTiming() - t.record(5.0, 6.0) - t.record(7.0, 8.0) - t.record(9.0, 10.0) - assert t.start_s == 5.0 - assert t.end_s == 10.0 - - def test_record_default_values(self): + def test_default_values(self): """Unrecorded StageTiming has start_s=0 and end_s=0.""" t = StageTiming() assert t.start_s == 0.0 From 16fea70f3ae648d15da5a2ea619f3acfb7827360 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:54:16 +0800 Subject: [PATCH 05/12] [Data] Capture upstream blocked time in fetch stage The fetch timing window in resolve_block_refs now spans from when we start waiting for the upstream iterator (blocked on the data pipeline) through ray.get() completion. This captures cross-node transfer and upstream production delays, giving a more complete picture of what blocks the training thread. Signed-off-by: OneSizeFitsQuorum --- .../ray/data/_internal/block_batching/util.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index a7dfffe9bee9..fcc6cf2a40c2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -179,7 +179,9 @@ def resolve_block_refs( """Resolves the block references for each logical batch. Each resolved block is wrapped in a :class:`BlockWithTiming` that carries - the per-block fetch window (``start_s``/``end_s`` around ``ray.get()``). + the per-block fetch window. The fetch window spans from the moment we + start waiting for the upstream iterator (blocked on the data pipeline or + cross-node transfer) until ``ray.get()`` returns the resolved block. When *stats* is provided, the cumulative fetch time is also recorded in ``stats.iter_get_s``. @@ -195,18 +197,28 @@ def resolve_block_refs( misses = 0 unknowns = 0 - for block_ref in block_ref_iter: - current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) - hits += current_hit - misses += current_miss - unknowns += current_unknown - - # TODO(amogkam): Optimized further by batching multiple references in a single - # `ray.get()` call. + while True: + # Time the upstream pull — captures blocked time waiting for the + # data pipeline to produce the next block ref. timings = BatchTimings() with timings.fetch: + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + + current_hit, current_miss, current_unknown = _calculate_ref_hits( + [block_ref] + ) + hits += current_hit + misses += current_miss + unknowns += current_unknown + + # TODO(amogkam): Optimized further by batching multiple references + # in a single `ray.get()` call. with stats.iter_get_s.timer() if stats else nullcontext(): block = ray.get(block_ref) + yield BlockWithTiming(block=block, timings=timings) if stats: From 71da139f9aa60f89db049334cc720df2c3a80076 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:57:35 +0800 Subject: [PATCH 06/12] [Data] Add docstrings to timing dataclasses and _report_batch_timings Per reviewer feedback, adds clear docstrings to: - BatchTimings (per-batch pipeline-stage timing windows) - BlockWithTiming (resolved block with fetch timing) - BatchTimings.merge_fetch() (multi-block fetch window expansion) - BatchTimings.stages() (stage name/timing iterator) - _report_batch_timings() (overlap-based attribution algorithm) Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 30 ++++++++++++++++++ .../_internal/block_batching/iter_batches.py | 31 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 4f3020e65412..f0cd2ad1ff7e 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -42,6 +42,23 @@ def timer(self): @dataclass class BatchTimings: + """Per-batch pipeline-stage timing windows for overlap-based attribution. + + Each field records the ``(start_s, end_s)`` wall-clock window during which + a particular pipeline stage was active for this batch. The training thread + later compares these windows against its own blocked window to determine + how much each stage contributed to training-thread stall (see + :meth:`BatchIterator._report_batch_timings`). + + Attributes: + fetch: Waiting for upstream data production + ``ray.get()`` transfer. + batching: Assembling blocks into a batch via ``_batcher.next_batch()``. + format: Converting the batch to the requested format (numpy, pandas…). + collate: Running the user-provided ``collate_fn``. + finalize: Running the user-provided ``finalize_fn`` (e.g. host→device). + num_rows: Number of rows in this batch (for ``iter_rows_total``). + """ + fetch: StageTiming = field(default_factory=StageTiming) batching: StageTiming = field(default_factory=StageTiming) format: StageTiming = field(default_factory=StageTiming) @@ -50,6 +67,7 @@ class BatchTimings: num_rows: int = 0 def stages(self) -> Iterable[Tuple[str, StageTiming]]: + """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" return ( ("fetch", self.fetch), ("batching", self.batching), @@ -59,6 +77,12 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ) def merge_fetch(self, other: "BatchTimings") -> None: + """Expand this batch's fetch window to encompass another's. + + Used when a single batch is assembled from multiple blocks, each + fetched independently. The merged window spans from the earliest + fetch start to the latest fetch end. + """ self._merge_stage(self.fetch, other.fetch) @staticmethod @@ -73,6 +97,12 @@ def _merge_stage(dst: StageTiming, src: StageTiming) -> None: @dataclass class BlockWithTiming: + """A resolved block paired with its fetch timing window. + + Produced by :func:`resolve_block_refs` so that downstream pipeline stages + can track how long each block took to fetch (upstream wait + ``ray.get()``). + """ + block: Block timings: BatchTimings = field(default_factory=BatchTimings) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index dbcfedecd748..bb4e34d6aad4 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -264,6 +264,37 @@ def _iter_batches(self) -> Iterator[DataBatch]: def _report_batch_timings( self, batch: Batch, blocked_start_s: float, blocked_end_s: float ) -> None: + """Attribute per-stage blocked time via overlap with the training window. + + For each pipeline stage we know when it ran ``[stage.start_s, + stage.end_s]`` (recorded by background threads onto + ``batch.metadata.timings``). We also know when the training thread + was blocked ``[blocked_start_s, blocked_end_s]`` (captured in + ``_iter_batches`` around ``next()``). + + The attribution for a stage is the length of the intersection:: + + overlap = min(stage.end, blocked_end) - max(stage.start, blocked_start) + + This correctly handles all prefetch configurations: + + * Stage finished before training blocked → overlap ≤ 0 → zero credit. + * Stage fully inside blocked window → full stage duration credited. + * Partial overlap → partial credit. + + **Invariant**: ``sum(iter_blocked_*) ≤ iter_total_blocked_s``. + + Runs in the training thread; no locks needed because background + threads finished writing ``batch.metadata.timings`` before the batch + was enqueued. + + Args: + batch: The batch whose per-stage timings should be attributed. + blocked_start_s: ``perf_counter()`` value just before the + training thread called ``next(batch_iter)``. + blocked_end_s: ``perf_counter()`` value just after ``next()`` + returned. + """ if self._stats is None: return timings = batch.metadata.timings From 9fcde561a5e2f6f0b227d0096c2068cbcca8228a Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 14:13:32 +0800 Subject: [PATCH 07/12] [Data] Restore isinstance check for BlockWithTiming compatibility _BatchingIterator can receive blocks from paths other than resolve_block_refs (e.g., doctest examples that pass raw pyarrow Tables). Restore the isinstance check to handle both BlockWithTiming and raw Block objects gracefully. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index fcc6cf2a40c2..2dca971d9fb8 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -14,6 +14,7 @@ Optional, Tuple, TypeVar, + Union, ) import ray @@ -228,7 +229,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -257,7 +258,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -317,9 +318,11 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block_with_timing = next(self._block_iter) - self._pending_timings.merge_fetch(block_with_timing.timings) - self._batcher.add(block_with_timing.block) + block = next(self._block_iter) + if isinstance(block, BlockWithTiming): + self._pending_timings.merge_fetch(block.timings) + block = block.block + self._batcher.add(block) except StopIteration: self._batcher.done_adding() self._done_adding = True From 1266c927fe8df6238825fc6650b98d5a1865cfc0 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 14:58:26 +0800 Subject: [PATCH 08/12] [Data] Refactor to eliminate isinstance/Union in _BatchingIterator Per reviewer feedback, removed isinstance check and Union type from _BatchingIterator by ensuring all entry points wrap blocks in BlockWithTiming: - batch_blocks() now wraps raw blocks in BlockWithTiming with zero timing before passing to blocks_to_batches() - _BatchingIterator now assumes all blocks are BlockWithTiming - Removed Union import from util.py This provides a uniform type throughout the batching pipeline while maintaining backward compatibility for external callers of batch_blocks(). Signed-off-by: OneSizeFitsQuorum --- .../data/_internal/block_batching/block_batching.py | 10 +++++++++- python/ray/data/_internal/block_batching/util.py | 13 +++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index ef54a593920b..e10a212de1a6 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -1,5 +1,9 @@ from typing import Callable, Iterator, Optional, TypeVar +from ray.data._internal.block_batching.interfaces import ( + BatchTimings, + BlockWithTiming, +) from ray.data._internal.block_batching.util import ( _MappingIterator, blocks_to_batches, @@ -29,10 +33,14 @@ def batch_blocks( This function takes in an iterator of already fetched blocks. Consequently, this function doesn't support block prefetching. """ + # Wrap raw blocks in BlockWithTiming with zero timing so that + # _BatchingIterator receives a uniform type. + wrapped_blocks = (BlockWithTiming(block=b, timings=BatchTimings()) for b in blocks) + # Build the processing pipeline batch_iter = format_batches( blocks_to_batches( - block_iter=blocks, + block_iter=wrapped_blocks, stats=stats, batch_size=batch_size, drop_last=drop_last, diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 2dca971d9fb8..fcc6cf2a40c2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -14,7 +14,6 @@ Optional, Tuple, TypeVar, - Union, ) import ray @@ -229,7 +228,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -258,7 +257,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -318,11 +317,9 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - if isinstance(block, BlockWithTiming): - self._pending_timings.merge_fetch(block.timings) - block = block.block - self._batcher.add(block) + block_with_timing = next(self._block_iter) + self._pending_timings.merge_fetch(block_with_timing.timings) + self._batcher.add(block_with_timing.block) except StopIteration: self._batcher.done_adding() self._done_adding = True From ff658f2b462f8d992dac833d7992c6830b5733f3 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:32:31 +0800 Subject: [PATCH 09/12] [Data] Fix merge_fetch idle gap and blocked window alignment Per Cursor Bugbot review: 1. merge_fetch now sums fetch durations instead of taking the span, avoiding counting idle gaps between consecutive block fetches as fetch blocking time. 2. Move blocked_start_s/blocked_end_s captures inside get_next_batch_context() so the blocked window aligns with iter_total_blocked_s, preventing sum(iter_blocked_*) from exceeding iter_total_blocked_s. Updated tests to reflect the new duration-summing behavior. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 18 +++++++++++----- .../_internal/block_batching/iter_batches.py | 4 ++-- .../tests/block_batching/test_iter_batches.py | 21 ++++++++++--------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index f0cd2ad1ff7e..c689dd504745 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -77,13 +77,21 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ) def merge_fetch(self, other: "BatchTimings") -> None: - """Expand this batch's fetch window to encompass another's. + """Merge fetch timings from another batch into this one. - Used when a single batch is assembled from multiple blocks, each - fetched independently. The merged window spans from the earliest - fetch start to the latest fetch end. + Sums the fetch durations rather than taking the span, to avoid + counting idle gaps between consecutive block fetches as fetch time. """ - self._merge_stage(self.fetch, other.fetch) + if other.fetch.start_s == 0.0: + return + if self.fetch.start_s == 0.0: + # First block: copy the timing + self.fetch.start_s = other.fetch.start_s + self.fetch.end_s = other.fetch.end_s + else: + # Subsequent blocks: add duration to existing span + duration = other.fetch.end_s - other.fetch.start_s + self.fetch.end_s += duration @staticmethod def _merge_stage(dst: StageTiming, src: StageTiming) -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index bb4e34d6aad4..dc669c9725fd 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -248,13 +248,13 @@ def _iter_batches(self) -> Iterator[DataBatch]: self.before_epoch_start() while True: - blocked_start_s = time.perf_counter() with self.get_next_batch_context(): + blocked_start_s = time.perf_counter() try: batch = next(batch_iter) except StopIteration: break - blocked_end_s = time.perf_counter() + blocked_end_s = time.perf_counter() self._report_batch_timings(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 3ca2c68725d8..074ee23d147d 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -365,27 +365,27 @@ def test_merge_single_block(self): assert dst.fetch.end_s == 2.0 def test_merge_multiple_blocks_expands_window(self): - """Merging multiple blocks produces the union window.""" + """Merging multiple blocks sums their fetch durations.""" dst = BatchTimings() - # Block 1: fetched [1.0, 2.0] + # Block 1: fetched [1.0, 2.0] (duration: 1.0) src1 = BatchTimings() src1.fetch = StageTiming(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) - # Block 2: fetched [3.0, 4.0] + # Block 2: fetched [3.0, 4.0] (duration: 1.0) src2 = BatchTimings() src2.fetch = StageTiming(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) - # Block 3: fetched [5.0, 6.0] + # Block 3: fetched [5.0, 6.0] (duration: 1.0) src3 = BatchTimings() src3.fetch = StageTiming(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) - # Union: [1.0, 6.0] + # Sum of durations: 1.0 + 1.0 + 1.0 = 3.0 assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 6.0 + assert dst.fetch.end_s == 4.0 # 1.0 + 3.0 def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" @@ -399,19 +399,20 @@ def test_merge_unrecorded_block_ignored(self): assert dst.fetch.end_s == 3.0 def test_merge_overlapping_blocks(self): - """Overlapping fetch windows are correctly merged.""" + """Overlapping fetch windows sum their durations.""" dst = BatchTimings() src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) # duration: 4.0 dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) # duration: 4.0 dst.merge_fetch(src2) + # Sum of durations: 4.0 + 4.0 = 8.0 assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 7.0 + assert dst.fetch.end_s == 9.0 # 1.0 + 8.0 def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" From b9bf847d7857a50d7511cfcbebf91487d0056f3d Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:48:10 +0800 Subject: [PATCH 10/12] [Data] Remove unused _merge_stage method After changing merge_fetch to sum durations instead of taking the span, the _merge_stage helper is no longer called anywhere. Remove the dead code. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/interfaces.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index c689dd504745..6292425c1cac 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -93,15 +93,6 @@ def merge_fetch(self, other: "BatchTimings") -> None: duration = other.fetch.end_s - other.fetch.start_s self.fetch.end_s += duration - @staticmethod - def _merge_stage(dst: StageTiming, src: StageTiming) -> None: - if src.start_s == 0.0: - return - if dst.start_s == 0.0 or src.start_s < dst.start_s: - dst.start_s = src.start_s - if src.end_s > dst.end_s: - dst.end_s = src.end_s - @dataclass class BlockWithTiming: From aa41de515f8f29560d9baea1fe3c5c24acda3701 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:59:17 +0800 Subject: [PATCH 11/12] [Data] Revert merge_fetch to span-based approach After deeper analysis, the span approach (taking [earliest_start, latest_end]) is semantically correct for multi-block fetches: - From the training thread's perspective, it's blocked for the entire span, even if there are gaps between consecutive block fetches - Those "idle gaps" are actually pipeline overhead (batching logic, scheduling) and are part of the blocking experience - Summing durations would underestimate the actual blocking time The Cursor Bugbot concern about "idle gaps" is valid in theory, but in practice: 1. The gaps are very small (microseconds of pipeline overhead) 2. They represent real blocking time from the training thread's perspective 3. Span aligns with the semantic meaning of "how long did training wait" Reverted tests to expect span behavior. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 14 +++++++----- .../tests/block_batching/test_iter_batches.py | 22 +++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 6292425c1cac..cfdd21e918a3 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -79,8 +79,10 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: def merge_fetch(self, other: "BatchTimings") -> None: """Merge fetch timings from another batch into this one. - Sums the fetch durations rather than taking the span, to avoid - counting idle gaps between consecutive block fetches as fetch time. + Expands the fetch window to span from the earliest block fetch start + to the latest block fetch end. This represents the total time the + training thread was blocked waiting for this batch, including any + pipeline overhead between consecutive block fetches. """ if other.fetch.start_s == 0.0: return @@ -89,9 +91,11 @@ def merge_fetch(self, other: "BatchTimings") -> None: self.fetch.start_s = other.fetch.start_s self.fetch.end_s = other.fetch.end_s else: - # Subsequent blocks: add duration to existing span - duration = other.fetch.end_s - other.fetch.start_s - self.fetch.end_s += duration + # Subsequent blocks: expand the window + if other.fetch.start_s < self.fetch.start_s: + self.fetch.start_s = other.fetch.start_s + if other.fetch.end_s > self.fetch.end_s: + self.fetch.end_s = other.fetch.end_s @dataclass diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 074ee23d147d..63128be388e3 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -365,27 +365,27 @@ def test_merge_single_block(self): assert dst.fetch.end_s == 2.0 def test_merge_multiple_blocks_expands_window(self): - """Merging multiple blocks sums their fetch durations.""" + """Merging multiple blocks produces the union window.""" dst = BatchTimings() - # Block 1: fetched [1.0, 2.0] (duration: 1.0) + # Block 1: fetched [1.0, 2.0] src1 = BatchTimings() src1.fetch = StageTiming(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) - # Block 2: fetched [3.0, 4.0] (duration: 1.0) + # Block 2: fetched [3.0, 4.0] src2 = BatchTimings() src2.fetch = StageTiming(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) - # Block 3: fetched [5.0, 6.0] (duration: 1.0) + # Block 3: fetched [5.0, 6.0] src3 = BatchTimings() src3.fetch = StageTiming(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) - # Sum of durations: 1.0 + 1.0 + 1.0 = 3.0 + # Union: [1.0, 6.0] assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 4.0 # 1.0 + 3.0 + assert dst.fetch.end_s == 6.0 def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" @@ -399,20 +399,20 @@ def test_merge_unrecorded_block_ignored(self): assert dst.fetch.end_s == 3.0 def test_merge_overlapping_blocks(self): - """Overlapping fetch windows sum their durations.""" + """Overlapping fetch windows are correctly merged.""" dst = BatchTimings() src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=5.0) # duration: 4.0 + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=7.0) # duration: 4.0 + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) dst.merge_fetch(src2) - # Sum of durations: 4.0 + 4.0 = 8.0 + # Union: [1.0, 7.0] assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 9.0 # 1.0 + 8.0 + assert dst.fetch.end_s == 7.0 def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" From 7874ef45b95889390c22ca418795dd54729cb4be Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Tue, 23 Jun 2026 09:57:09 +0800 Subject: [PATCH 12/12] [Data] Fix test failures from BlockWithTiming refactor - test_util.py: Updated test_resolve_block_refs to expect BlockWithTiming objects and test_blocks_to_batches to wrap raw blocks - block_batching.py: Changed generator expression to map() to avoid holding references to blocks, fixing test_chained_transforms_release_intermediates Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/block_batching.py | 8 ++++++-- .../ray/data/tests/block_batching/test_util.py | 17 ++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index e10a212de1a6..999e2e10af9b 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -34,8 +34,12 @@ def batch_blocks( function doesn't support block prefetching. """ # Wrap raw blocks in BlockWithTiming with zero timing so that - # _BatchingIterator receives a uniform type. - wrapped_blocks = (BlockWithTiming(block=b, timings=BatchTimings()) for b in blocks) + # _BatchingIterator receives a uniform type. Use map() instead of a + # generator expression to avoid holding references to blocks. + def _wrap_block(b): + return BlockWithTiming(block=b, timings=BatchTimings()) + + wrapped_blocks = map(_wrap_block, blocks) # Build the processing pipeline batch_iter = format_batches( diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 6ead5741f0e1..7a6223b06f78 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -13,7 +13,12 @@ import pytest import ray -from ray.data._internal.block_batching.interfaces import Batch, BatchMetadata +from ray.data._internal.block_batching.interfaces import ( + Batch, + BatchMetadata, + BatchTimings, + BlockWithTiming, +) from ray.data._internal.block_batching.util import ( _calculate_ref_hits, blocks_to_batches, @@ -37,7 +42,9 @@ def test_resolve_block_refs(ray_start_regular_shared): block_refs = [ray.put(0), ray.put(1), ray.put(2)] resolved_iter = resolve_block_refs(iter(block_refs)) - assert list(resolved_iter) == [0, 1, 2] + resolved = list(resolved_iter) + assert all(isinstance(b, BlockWithTiming) for b in resolved) + assert [b.block for b in resolved] == [0, 1, 2] @pytest.mark.parametrize("block_size", [1, 10]) @@ -45,10 +52,14 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_blocks_to_batches(block_size, drop_last): num_blocks = 5 block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + # Wrap raw blocks in BlockWithTiming as blocks_to_batches now expects + wrapped_blocks = ( + BlockWithTiming(block=b, timings=BatchTimings()) for b in block_iter + ) batch_size = 3 batch_iter = list( - blocks_to_batches(block_iter, batch_size=batch_size, drop_last=drop_last) + blocks_to_batches(wrapped_blocks, batch_size=batch_size, drop_last=drop_last) ) if drop_last: