Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion python/ray/data/_internal/block_batching/block_batching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Callable, Iterator, Optional, TypeVar

from ray.data._internal.block_batching.interfaces import (
BatchTimings,
BlockWithTiming,
)
from ray.data._internal.block_batching.util import (
_MappingIterator,
blocks_to_batches,
Expand Down Expand Up @@ -29,10 +33,18 @@ def batch_blocks(
This function takes in an iterator of already fetched blocks. Consequently, this
function doesn't support block prefetching.
"""
# Wrap raw blocks in BlockWithTiming with zero timing so that
# _BatchingIterator receives a uniform type. Use map() instead of a
# generator expression to avoid holding references to blocks.
def _wrap_block(b):
return BlockWithTiming(block=b, timings=BatchTimings())

wrapped_blocks = map(_wrap_block, blocks)

# Build the processing pipeline
batch_iter = format_batches(
blocks_to_batches(
block_iter=blocks,
block_iter=wrapped_blocks,
stats=stats,
batch_size=batch_size,
drop_last=drop_last,
Expand Down
110 changes: 108 additions & 2 deletions python/ray/data/_internal/block_batching/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,127 @@
import abc
from dataclasses import dataclass
from typing import Any, List
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Tuple

from ray.data.block import Block, DataBatch
from ray.types import ObjectRef


@dataclass
class StageTiming:
"""Wall-clock window for a single batch-processing stage.
Can be used as a context manager to automatically capture the start and
end timestamps of a pipeline operation::
with stage_timing:
do_work()
# stage_timing.start_s and stage_timing.end_s are now set
"""

start_s: float = 0.0
end_s: float = 0.0

def __enter__(self):
self.start_s = time.perf_counter()
return self

def __exit__(self, *args):
self.end_s = time.perf_counter()

@contextmanager
def timer(self):
"""Alias for using as a context manager, matching Timer.timer() API."""
self.start_s = time.perf_counter()
try:
yield
finally:
self.end_s = time.perf_counter()


@dataclass
class BatchTimings:
"""Per-batch pipeline-stage timing windows for overlap-based attribution.
Each field records the ``(start_s, end_s)`` wall-clock window during which
a particular pipeline stage was active for this batch. The training thread
later compares these windows against its own blocked window to determine
how much each stage contributed to training-thread stall (see
:meth:`BatchIterator._report_batch_timings`).
Attributes:
fetch: Waiting for upstream data production + ``ray.get()`` transfer.
batching: Assembling blocks into a batch via ``_batcher.next_batch()``.
format: Converting the batch to the requested format (numpy, pandas…).
collate: Running the user-provided ``collate_fn``.
finalize: Running the user-provided ``finalize_fn`` (e.g. host→device).
num_rows: Number of rows in this batch (for ``iter_rows_total``).
"""

fetch: StageTiming = field(default_factory=StageTiming)
batching: StageTiming = field(default_factory=StageTiming)
format: StageTiming = field(default_factory=StageTiming)
collate: StageTiming = field(default_factory=StageTiming)
finalize: StageTiming = field(default_factory=StageTiming)
num_rows: int = 0

def stages(self) -> Iterable[Tuple[str, StageTiming]]:
"""Iterate over ``(name, timing)`` pairs for all pipeline stages."""
return (
("fetch", self.fetch),
("batching", self.batching),
("format", self.format),
("collate", self.collate),
("finalize", self.finalize),
)

def merge_fetch(self, other: "BatchTimings") -> None:
"""Merge fetch timings from another batch into this one.
Expands the fetch window to span from the earliest block fetch start
to the latest block fetch end. This represents the total time the
training thread was blocked waiting for this batch, including any
pipeline overhead between consecutive block fetches.
"""
if other.fetch.start_s == 0.0:
return
if self.fetch.start_s == 0.0:
# First block: copy the timing
self.fetch.start_s = other.fetch.start_s
self.fetch.end_s = other.fetch.end_s
else:
# Subsequent blocks: expand the window
if other.fetch.start_s < self.fetch.start_s:
self.fetch.start_s = other.fetch.start_s
if other.fetch.end_s > self.fetch.end_s:
self.fetch.end_s = other.fetch.end_s


@dataclass
class BlockWithTiming:
"""A resolved block paired with its fetch timing window.
Produced by :func:`resolve_block_refs` so that downstream pipeline stages
can track how long each block took to fetch (upstream wait + ``ray.get()``).
"""

block: Block
timings: BatchTimings = field(default_factory=BatchTimings)


@dataclass
class BatchMetadata:
"""Metadata associated with a batch.
Attributes:
batch_idx: The global index of this batch so that downstream operations can
maintain ordering.
timings: Pipeline-stage timing windows for this batch.
"""

batch_idx: int
timings: BatchTimings = field(default_factory=BatchTimings)


@dataclass
Expand Down
53 changes: 52 additions & 1 deletion python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _prefetch_blocks(

def _resolve_block_refs(
self, block_refs: Iterator[ObjectRef[Block]]
) -> Iterator[Block]:
) -> Iterator[Any]:
return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats)

def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]:
Expand Down Expand Up @@ -249,15 +249,66 @@ def _iter_batches(self) -> Iterator[DataBatch]:

while True:
with self.get_next_batch_context():
blocked_start_s = time.perf_counter()
try:
batch = next(batch_iter)
except StopIteration:
break
blocked_end_s = time.perf_counter()
self._report_batch_timings(batch, blocked_start_s, blocked_end_s)
with self.yield_batch_context(batch):
yield batch.data

self.after_epoch_end()

def _report_batch_timings(
Comment thread
OneSizeFitsQuorum marked this conversation as resolved.
self, batch: Batch, blocked_start_s: float, blocked_end_s: float
) -> None:
"""Attribute per-stage blocked time via overlap with the training window.

For each pipeline stage we know when it ran ``[stage.start_s,
stage.end_s]`` (recorded by background threads onto
``batch.metadata.timings``). We also know when the training thread
was blocked ``[blocked_start_s, blocked_end_s]`` (captured in
``_iter_batches`` around ``next()``).

The attribution for a stage is the length of the intersection::

overlap = min(stage.end, blocked_end) - max(stage.start, blocked_start)

This correctly handles all prefetch configurations:

* Stage finished before training blocked → overlap ≤ 0 → zero credit.
* Stage fully inside blocked window → full stage duration credited.
* Partial overlap → partial credit.

**Invariant**: ``sum(iter_blocked_*) ≤ iter_total_blocked_s``.

Runs in the training thread; no locks needed because background
threads finished writing ``batch.metadata.timings`` before the batch
was enqueued.

Args:
batch: The batch whose per-stage timings should be attributed.
blocked_start_s: ``perf_counter()`` value just before the
training thread called ``next(batch_iter)``.
blocked_end_s: ``perf_counter()`` value just after ``next()``
returned.
"""
if self._stats is None:
return
timings = batch.metadata.timings
for name, stage in timings.stages():
if stage.start_s == 0.0 and stage.end_s == 0.0:
continue
overlap_s = min(stage.end_s, blocked_end_s) - max(
stage.start_s, blocked_start_s
)
if overlap_s > 0:
getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s)
self._stats.iter_batches_total += 1
self._stats.iter_rows_total += timings.num_rows

def __iter__(self) -> Iterator[DataBatch]:
return self._iter_batches()

Expand Down
Loading
Loading