diff --git a/README.md b/README.md index 069efc4..1b3723d 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,13 @@ when `pipeline(...)` is called. Materialization decisions are compiled into the resolved materializer callables are frozen before `run()` starts. If it compiles, it's valid. No runtime surprises. +### Runtime overrides on top of the compiled contract + +When you need test-time swaps without patching module globals, pass +`ExecutionOverrides` to `run()` or `async_run()` and replace only the compiled +materializers you care about. The DAG shape and semantics stay fixed; only the +runtime callable changes. + ### Build your own runner The DAG compiles to a deterministic JSON contract. Write custom runners or diff --git a/docs/user_docs/core-concepts/build-vs-run.md b/docs/user_docs/core-concepts/build-vs-run.md index 7d61b99..bbfa494 100644 --- a/docs/user_docs/core-concepts/build-vs-run.md +++ b/docs/user_docs/core-concepts/build-vs-run.md @@ -29,7 +29,7 @@ flowchart LR | Phase | What happens | When | Output | |---|---|---|---| | **Build-time** | Type validation, mode resolution, materializer assignment, consumer materialization planning, circular dependency check, sync/async consistency | `pipeline(...)` is called | `Dag` object, serializable JSON | -| **Run-time** | Topological execution, lockstep streaming, bounded handoff via `max_in_flight`, `tee` forking, observer dispatch, error handling | `run()` / `async_run()` is called | Step outputs, side effects | +| **Run-time** | Topological execution, lockstep streaming, bounded handoff via `max_in_flight`, `tee` forking, observer dispatch, error handling, optional `ExecutionOverrides` on top of the compiled contract | `run()` / `async_run()` is called | Step outputs, side effects | ## Why this matters @@ -67,6 +67,10 @@ All semantic decisions — mode, `max_in_flight`, `each_mode_deps`, resolved at build time and frozen in the JSON or `Dag`. Runners don't re-infer semantics; they execute the contract. +`ExecutionOverrides` fits inside that boundary: it can swap the concrete +runtime callable for a compiled key such as a materializer, but it does not +change graph structure, dependency resolution, or eager-vs-lazy planning. + ### 2. Write your own runner The `Dag` object is self-contained. Anyone can write a runner: diff --git a/synaflow/__init__.py b/synaflow/__init__.py index 6bd5610..685852b 100644 --- a/synaflow/__init__.py +++ b/synaflow/__init__.py @@ -6,6 +6,7 @@ StepEvent, ) from .core.types import OnError, StepMode, StepParams, StepResult +from .execution import ExecutionOverrides from .execution.async_engine.executor import async_run from .execution.sync_engine.executor import run from .serializers import ( @@ -22,6 +23,7 @@ "include", "run", "async_run", + "ExecutionOverrides", "OnError", "StepMode", "StepParams", diff --git a/synaflow/execution/__init__.py b/synaflow/execution/__init__.py index e69de29..826ce65 100644 --- a/synaflow/execution/__init__.py +++ b/synaflow/execution/__init__.py @@ -0,0 +1,7 @@ +from .overrides import ExecutionOverrides, MaterializerRegistry, PipelineRegistry + +__all__ = [ + "ExecutionOverrides", + "MaterializerRegistry", + "PipelineRegistry", +] diff --git a/synaflow/execution/async_engine/executor.py b/synaflow/execution/async_engine/executor.py index 45f5c8d..e85bbc1 100644 --- a/synaflow/execution/async_engine/executor.py +++ b/synaflow/execution/async_engine/executor.py @@ -25,6 +25,7 @@ OnError, StepMode, ) +from synaflow.execution.overrides import ExecutionOverrides from .constants import EOF_MARKER from .iterator_utils import AsyncQueueBranch, queue_to_async_gen @@ -64,28 +65,30 @@ async def _collect_async_iterator( async def _apply_materializer( - dag: Dag, step_name: str, value: Any, consumer_type: Any = None + dag: Dag, + step_name: str, + value: Any, + materializer: Any, + consumer_type: Any = None, ) -> tuple[Any, bool, BaseException | None]: - node = dag[step_name] - mat = node.get("materializer") - if mat is None: + if materializer is None: if isinstance(value, (AsyncIterator, AsyncGenerator, Iterator, Generator)): items, had_error, exc = await _collect_async_iterator(dag, step_name, value) return items, had_error, exc return value, False, None - if inspect.iscoroutinefunction(mat): - result = await mat(value) + if inspect.iscoroutinefunction(materializer): + result = await materializer(value) return result, False, None if isinstance(value, (AsyncIterator, AsyncGenerator, Iterator, Generator)): items, had_error, exc = await _collect_async_iterator(dag, step_name, value) - res = mat(items) + res = materializer(items) if inspect.iscoroutine(res): return await res, had_error, exc return res, had_error, exc - res = mat(value) + res = materializer(value) if inspect.iscoroutine(res): return await res, False, None return res, False, None @@ -171,7 +174,11 @@ async def _safe_iterate(name: str, iterable: Any): async def _resolve_queue( - dag: Dag, producer: str, queue: asyncio.Queue, consumer_type: Any + dag: Dag, + producer: str, + queue: asyncio.Queue, + consumer_type: Any, + materializer: Any, ) -> Any: if consumer_type in (AsyncIterator, AsyncGenerator): if isinstance(queue, AsyncQueueBranch): @@ -183,7 +190,11 @@ async def _resolve_queue( return queue return queue_to_async_gen(queue) result, _, _ = await _apply_materializer( - dag, producer, queue_to_async_gen(queue), consumer_type=consumer_type + dag, + producer, + queue_to_async_gen(queue), + materializer, + consumer_type=consumer_type, ) return result @@ -194,16 +205,28 @@ async def _resolve_queue( class AsyncPipelineExecutor: - def __init__(self, dag: Dag, *, step_output_observers: list = None): + def __init__( + self, + dag: Dag, + *, + step_output_observers: list = None, + overrides: ExecutionOverrides | None = None, + ): self.dag = dag self.outputs = {} self._pump_tasks: list[asyncio.Task] = [] self._step_output_observers = step_output_observers or [] + self._overrides = overrides # ------------------------------------------------------------------ # Lifecycle observer dispatch helpers (async) # ------------------------------------------------------------------ + def _resolve_materializer(self, step_name: str, node: Any) -> Any: + if self._overrides is None: + return node.materializer + return self._overrides.materializers.resolve(step_name, node.materializer) + async def _dispatch_pipeline_event( self, event: PipelineEvent, @@ -501,7 +524,11 @@ async def _build_arguments(self, consumer, node, unrolled): and dep_name not in unrolled ): dep_type = node.deps.get(dep_name) - value = await _resolve_queue(self.dag, dep_name, value, dep_type) + producer_node = self.dag[dep_name] + materializer = self._resolve_materializer(dep_name, producer_node) + value = await _resolve_queue( + self.dag, dep_name, value, dep_type, materializer + ) param = node.dataset_param_names.get(dep_name, dep_name) args[param] = value return args @@ -537,7 +564,8 @@ def _notify_observers(self, step_name, output): async def _materialize_with_events( self, step_name, output, node, consumer_type=None ): - mat_name = node.materializer.__name__ if callable(node.materializer) else None + materializer = self._resolve_materializer(step_name, node) + mat_name = materializer.__name__ if callable(materializer) else None await self._dispatch_materialization_event( step_name, node, @@ -547,7 +575,11 @@ async def _materialize_with_events( ) try: result, had_error, exc = await _apply_materializer( - self.dag, step_name, output, consumer_type=consumer_type + self.dag, + step_name, + output, + materializer, + consumer_type=consumer_type, ) await self._dispatch_materialization_event( step_name, @@ -764,10 +796,14 @@ async def _publish_output(self, step_name, output, node): # --------------------------------------------------------------------------- -async def async_run(pipeline: PipelineDef, params: Any) -> None: +async def async_run( + pipeline: PipelineDef, + params: Any, + overrides: ExecutionOverrides | None = None, +) -> None: if getattr(pipeline, "requires_sync_runner", False): raise RuntimeError( "This pipeline contains synchronous streams (Iterator)." " It must be executed with run() or migrated to AsyncIterator." ) - await AsyncPipelineExecutor(pipeline.dag).execute(params) + await AsyncPipelineExecutor(pipeline.dag, overrides=overrides).execute(params) diff --git a/synaflow/execution/overrides.py b/synaflow/execution/overrides.py new file mode 100644 index 0000000..28174b9 --- /dev/null +++ b/synaflow/execution/overrides.py @@ -0,0 +1,103 @@ +from collections.abc import Iterator, MutableMapping +from dataclasses import dataclass +from typing import Any + +from synaflow.core.definition import PipelineDef + + +class PipelineRegistry(MutableMapping[str, Any]): + def __init__( + self, + *, + contract_keys: set[str], + fallback_values: dict[str, Any] | None = None, + ) -> None: + self._contract_keys = set(contract_keys) + self._fallback_values = dict(fallback_values or {}) + self._overrides: dict[str, Any] = {} + + def __getitem__(self, key: str) -> Any: + self._validate_key(key) + if key in self._overrides: + return self._overrides[key] + if key in self._fallback_values: + return self._fallback_values[key] + raise KeyError(key) + + def __setitem__(self, key: str, value: Any) -> None: + self._validate_key(key) + self._validate_value(key, value) + self._overrides[key] = value + + def __delitem__(self, key: str) -> None: + self._validate_key(key) + if key not in self._overrides: + raise KeyError(key) + del self._overrides[key] + + def __iter__(self) -> Iterator[str]: + return iter(sorted(self._contract_keys)) + + def __len__(self) -> int: + return len(self._contract_keys) + + def resolve(self, key: str, default: Any = None) -> Any: + if key in self._overrides: + return self._overrides[key] + if key in self._fallback_values: + return self._fallback_values[key] + return default + + def _validate_key(self, key: str) -> None: + if key not in self._contract_keys: + valid = ", ".join(sorted(self._contract_keys)) or "" + raise KeyError(f"Unknown override key '{key}'. Valid keys: {valid}.") + + def _validate_value(self, key: str, value: Any) -> None: + return None + + +class MaterializerRegistry(PipelineRegistry): + @classmethod + def empty(cls, pipeline: PipelineDef) -> "MaterializerRegistry": + return cls( + contract_keys=_materializer_contract_keys(pipeline), + fallback_values=_materializer_fallback_values(pipeline), + ) + + @classmethod + def from_production(cls, pipeline: PipelineDef) -> "MaterializerRegistry": + return cls.empty(pipeline) + + def _validate_value(self, key: str, value: Any) -> None: + if not callable(value): + raise TypeError(f"Materializer override for step '{key}' must be callable.") + + +@dataclass(frozen=True) +class ExecutionOverrides: + materializers: MaterializerRegistry + + @classmethod + def empty(cls, pipeline: PipelineDef) -> "ExecutionOverrides": + return cls(materializers=MaterializerRegistry.empty(pipeline)) + + @classmethod + def from_production(cls, pipeline: PipelineDef) -> "ExecutionOverrides": + return cls(materializers=MaterializerRegistry.from_production(pipeline)) + + +def _materializer_contract_keys(pipeline: PipelineDef) -> set[str]: + return { + step_name + for step_name, node in pipeline.dag.steps.items() + if node.materializer is not None + } + + +def _materializer_fallback_values(pipeline: PipelineDef) -> dict[str, Any]: + return { + step_name: node.materializer + for step_name, node in pipeline.dag.steps.items() + if node.materializer is not None + } diff --git a/synaflow/execution/sync_engine/executor.py b/synaflow/execution/sync_engine/executor.py index 0560b30..fd8afcf 100644 --- a/synaflow/execution/sync_engine/executor.py +++ b/synaflow/execution/sync_engine/executor.py @@ -27,6 +27,7 @@ OnError, StepMode, ) +from synaflow.execution.overrides import ExecutionOverrides from synaflow.execution.sync_handoff import ( SyncFanout, SyncMaterializedValue, @@ -65,11 +66,13 @@ def _collect_iterator( def _apply_materializer( - dag: Dag, step_name: str, value: Any, consumer_type: Any = None + dag: Dag, + step_name: str, + value: Any, + materializer: Any, + consumer_type: Any = None, ) -> tuple[Any, bool, BaseException | None]: - node = dag[step_name] - mat = node.get("materializer") - if mat is None: + if materializer is None: if isinstance(value, Iterator): items, had_error, exc = _collect_iterator(dag, step_name, value) return items, had_error, exc @@ -77,9 +80,9 @@ def _apply_materializer( if isinstance(value, Iterator): items, had_error, exc = _collect_iterator(dag, step_name, value) - return mat(items), had_error, exc + return materializer(items), had_error, exc - return mat(value), False, None + return materializer(value), False, None def _handle_error(dag: Dag, step_name: str, exc: BaseException) -> None: @@ -103,12 +106,19 @@ def _handle_error(dag: Dag, step_name: str, exc: BaseException) -> None: class PipelineExecutor: - def __init__(self, dag: Dag, *, step_output_observers: list = None): + def __init__( + self, + dag: Dag, + *, + step_output_observers: list = None, + overrides: ExecutionOverrides | None = None, + ): self.dag = dag self.outputs = {} self._step_output_observers = step_output_observers or [] self._active_fanouts: list[SyncFanout] = [] self._observer_threads: list[threading.Thread] = [] + self._overrides = overrides # ------------------------------------------------------------------ # Lifecycle observer dispatch helpers @@ -237,6 +247,11 @@ def _dispatch_materialization_event( # Execution # ------------------------------------------------------------------ + def _resolve_materializer(self, step_name: str, node: Any) -> Any: + if self._overrides is None: + return node.materializer + return self._overrides.materializers.resolve(step_name, node.materializer) + def execute(self, params: Any) -> None: for field, value in params._asdict().items(): self.outputs[field] = value @@ -500,7 +515,8 @@ def run_observer(obs=observer, branch_iter=iterator): self._observer_threads.append(thread) def _materialize_with_events(self, step_name, output, node, consumer_type=None): - mat_name = node.materializer.__name__ if callable(node.materializer) else None + materializer = self._resolve_materializer(step_name, node) + mat_name = materializer.__name__ if callable(materializer) else None self._dispatch_materialization_event( step_name, node, @@ -510,7 +526,11 @@ def _materialize_with_events(self, step_name, output, node, consumer_type=None): ) try: result, had_error, exc = _apply_materializer( - self.dag, step_name, output, consumer_type=consumer_type + self.dag, + step_name, + output, + materializer, + consumer_type=consumer_type, ) self._dispatch_materialization_event( step_name, @@ -747,10 +767,12 @@ def _publish_output(self, step_name, output, node): # --------------------------------------------------------------------------- -def run(pipeline: PipelineDef, params: Any) -> None: +def run( + pipeline: PipelineDef, params: Any, overrides: ExecutionOverrides | None = None +) -> None: if getattr(pipeline, "requires_async_runner", False): raise RuntimeError( "This pipeline contains async features (async def or AsyncIterator)" " and must be executed with async_run()." ) - PipelineExecutor(pipeline.dag).execute(params) + PipelineExecutor(pipeline.dag, overrides=overrides).execute(params) diff --git a/tests/execution/test_execution_overrides.py b/tests/execution/test_execution_overrides.py new file mode 100644 index 0000000..b350fb4 --- /dev/null +++ b/tests/execution/test_execution_overrides.py @@ -0,0 +1,135 @@ +from typing import AsyncGenerator, Iterator, NamedTuple + +import pytest + +from synaflow import ExecutionOverrides, async_run, pipeline, step + + +def test_given_materializer_override_when_sync_run_then_override_is_used( + run_pipeline, +): + class Params(NamedTuple): + count: int = 3 + + def gen(count: int) -> Iterator[int]: + yield from range(count) + + captured = [] + + def consume(items: list[int]) -> None: + captured.append(items) + + p = pipeline( + name="sync_override", + params=Params, + steps=[ + step("items", fn=gen), + step("consume", fn=consume), + ], + ) + + overrides = ExecutionOverrides.empty(p) + overrides.materializers["items"] = tuple + + run_pipeline(p, Params(), overrides=overrides) + + assert captured == [(0, 1, 2)] + + +async def test_given_materializer_override_when_async_run_then_override_is_used(): + class Params(NamedTuple): + count: int = 3 + + async def gen(count: int) -> AsyncGenerator[int, None]: + for item in range(count): + yield item + + captured = [] + + async def consume(items: list[int]) -> None: + captured.append(items) + + p = pipeline( + name="async_override", + params=Params, + steps=[ + step("items", fn=gen), + step("consume", fn=consume), + ], + ) + + overrides = ExecutionOverrides.empty(p) + overrides.materializers["items"] = tuple + + await async_run(p, Params(), overrides=overrides) + + assert captured == [(0, 1, 2)] + + +def test_given_execution_overrides_from_production_when_materializer_requested_then_returns_compiled_callable(): + class Params(NamedTuple): + count: int = 1 + + def gen(count: int) -> Iterator[int]: + yield from range(count) + + def consume(items: list[int]) -> None: + return None + + p = pipeline( + name="compiled_materializer_contract", + params=Params, + steps=[ + step("items", fn=gen), + step("consume", fn=consume), + ], + ) + + overrides = ExecutionOverrides.from_production(p) + + assert list(overrides.materializers) == ["items"] + assert overrides.materializers["items"] is list + + +def test_given_unknown_materializer_override_key_when_assigned_then_raises(): + class Params(NamedTuple): + value: int = 1 + + def emit(value: int) -> int: + return value + + p = pipeline( + name="invalid_override_key", + params=Params, + steps=[step("emit", fn=emit)], + ) + + overrides = ExecutionOverrides.empty(p) + + with pytest.raises(KeyError, match="Unknown override key 'missing'"): + overrides.materializers["missing"] = tuple + + +def test_given_non_callable_materializer_override_when_assigned_then_raises(): + class Params(NamedTuple): + count: int = 1 + + def gen(count: int) -> Iterator[int]: + yield from range(count) + + def consume(items: list[int]) -> None: + return None + + p = pipeline( + name="invalid_override_value", + params=Params, + steps=[ + step("items", fn=gen), + step("consume", fn=consume), + ], + ) + + overrides = ExecutionOverrides.empty(p) + + with pytest.raises(TypeError, match="must be callable"): + overrides.materializers["items"] = 123