Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ when `pipeline(...)` is called. Materialization decisions are compiled into the
resolved materializer callables are frozen before `run()` starts. If it
compiles, it's valid. No runtime surprises.

### Runtime overrides on top of the compiled contract

When you need test-time swaps without patching module globals, pass
`ExecutionOverrides` to `run()` or `async_run()` and replace only the compiled
materializers you care about. Only steps that already have a compiled
materializer can be overridden, and each replacement must be callable. The DAG
shape and semantics stay fixed; only the runtime callable changes.

### Build your own runner

The DAG compiles to a deterministic JSON contract. Write custom runners or
Expand Down
6 changes: 5 additions & 1 deletion docs/user_docs/core-concepts/build-vs-run.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ flowchart LR
| Phase | What happens | When | Output |
|---|---|---|---|
| **Build-time** | Type validation, mode resolution, materializer assignment, consumer materialization planning, circular dependency check, sync/async consistency | `pipeline(...)` is called | `Dag` object, serializable JSON |
| **Run-time** | Topological execution, lockstep streaming, bounded handoff via `max_in_flight`, `tee` forking, observer dispatch, error handling | `run()` / `async_run()` is called | Step outputs, side effects |
| **Run-time** | Topological execution, lockstep streaming, bounded handoff via `max_in_flight`, `tee` forking, observer dispatch, error handling, optional `ExecutionOverrides` on top of the compiled contract | `run()` / `async_run()` is called | Step outputs, side effects |

## Why this matters

Expand Down Expand Up @@ -67,6 +67,10 @@ All semantic decisions — mode, `max_in_flight`, `each_mode_deps`,
resolved at build time and frozen in the JSON or `Dag`. Runners don't re-infer
semantics; they execute the contract.

`ExecutionOverrides` fits inside that boundary: it can swap the concrete
runtime callable bound to a compiled key (for example, a step's materializer),
but it does not change graph structure, dependency resolution, or
eager-vs-lazy planning.

### 2. Write your own runner

The `Dag` object is self-contained. Anyone can write a runner:
Expand Down
2 changes: 2 additions & 0 deletions synaflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
StepEvent,
)
from .core.types import OnError, StepMode, StepParams, StepResult
from .execution import ExecutionOverrides
from .execution.async_engine.executor import async_run
from .execution.sync_engine.executor import run
from .serializers import (
Expand All @@ -22,6 +23,7 @@
"include",
"run",
"async_run",
"ExecutionOverrides",
"OnError",
"StepMode",
"StepParams",
Expand Down
7 changes: 7 additions & 0 deletions synaflow/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .overrides import ExecutionOverrides, MaterializerRegistry, PipelineRegistry

# Public names of the execution package; all three are re-exported from the
# sibling ``overrides`` module imported above.
__all__ = [
"ExecutionOverrides",
"MaterializerRegistry",
"PipelineRegistry",
]
68 changes: 52 additions & 16 deletions synaflow/execution/async_engine/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OnError,
StepMode,
)
from synaflow.execution.overrides import ExecutionOverrides

from .constants import EOF_MARKER
from .iterator_utils import AsyncQueueBranch, queue_to_async_gen
Expand Down Expand Up @@ -64,28 +65,30 @@ async def _collect_async_iterator(


async def _apply_materializer(
dag: Dag, step_name: str, value: Any, consumer_type: Any = None
dag: Dag,
step_name: str,
value: Any,
materializer: Any,
consumer_type: Any = None,
) -> tuple[Any, bool, BaseException | None]:
node = dag[step_name]
mat = node.get("materializer")
if mat is None:
if materializer is None:
if isinstance(value, (AsyncIterator, AsyncGenerator, Iterator, Generator)):
items, had_error, exc = await _collect_async_iterator(dag, step_name, value)
return items, had_error, exc
return value, False, None

if inspect.iscoroutinefunction(mat):
result = await mat(value)
if inspect.iscoroutinefunction(materializer):
result = await materializer(value)
return result, False, None

if isinstance(value, (AsyncIterator, AsyncGenerator, Iterator, Generator)):
items, had_error, exc = await _collect_async_iterator(dag, step_name, value)
res = mat(items)
res = materializer(items)
if inspect.iscoroutine(res):
return await res, had_error, exc
return res, had_error, exc

res = mat(value)
res = materializer(value)
if inspect.iscoroutine(res):
return await res, False, None
return res, False, None
Expand Down Expand Up @@ -171,7 +174,11 @@ async def _safe_iterate(name: str, iterable: Any):


async def _resolve_queue(
dag: Dag, producer: str, queue: asyncio.Queue, consumer_type: Any
dag: Dag,
producer: str,
queue: asyncio.Queue,
consumer_type: Any,
materializer: Any,
) -> Any:
if consumer_type in (AsyncIterator, AsyncGenerator):
if isinstance(queue, AsyncQueueBranch):
Expand All @@ -183,7 +190,11 @@ async def _resolve_queue(
return queue
return queue_to_async_gen(queue)
result, _, _ = await _apply_materializer(
dag, producer, queue_to_async_gen(queue), consumer_type=consumer_type
dag,
producer,
queue_to_async_gen(queue),
materializer,
consumer_type=consumer_type,
)
return result

Expand All @@ -194,16 +205,28 @@ async def _resolve_queue(


class AsyncPipelineExecutor:
def __init__(self, dag: Dag, *, step_output_observers: list = None):
def __init__(
self,
dag: Dag,
*,
step_output_observers: list = None,
overrides: ExecutionOverrides | None = None,
):
self.dag = dag
self.outputs = {}
self._pump_tasks: list[asyncio.Task] = []
self._step_output_observers = step_output_observers or []
self._overrides = overrides

# ------------------------------------------------------------------
# Lifecycle observer dispatch helpers (async)
# ------------------------------------------------------------------

def _resolve_materializer(self, step_name: str, node: Any) -> Any:
if self._overrides is None:
return node.materializer
return self._overrides.materializers.resolve(step_name, node.materializer)

async def _dispatch_pipeline_event(
self,
event: PipelineEvent,
Expand Down Expand Up @@ -501,7 +524,11 @@ async def _build_arguments(self, consumer, node, unrolled):
and dep_name not in unrolled
):
dep_type = node.deps.get(dep_name)
value = await _resolve_queue(self.dag, dep_name, value, dep_type)
producer_node = self.dag[dep_name]
materializer = self._resolve_materializer(dep_name, producer_node)
value = await _resolve_queue(
self.dag, dep_name, value, dep_type, materializer
)
param = node.dataset_param_names.get(dep_name, dep_name)
args[param] = value
return args
Expand Down Expand Up @@ -537,7 +564,8 @@ def _notify_observers(self, step_name, output):
async def _materialize_with_events(
self, step_name, output, node, consumer_type=None
):
mat_name = node.materializer.__name__ if callable(node.materializer) else None
materializer = self._resolve_materializer(step_name, node)
mat_name = materializer.__name__ if callable(materializer) else None
await self._dispatch_materialization_event(
step_name,
node,
Expand All @@ -547,7 +575,11 @@ async def _materialize_with_events(
)
try:
result, had_error, exc = await _apply_materializer(
self.dag, step_name, output, consumer_type=consumer_type
self.dag,
step_name,
output,
materializer,
consumer_type=consumer_type,
)
await self._dispatch_materialization_event(
step_name,
Expand Down Expand Up @@ -764,10 +796,14 @@ async def _publish_output(self, step_name, output, node):
# ---------------------------------------------------------------------------


async def async_run(pipeline: PipelineDef, params: Any) -> None:
async def async_run(
pipeline: PipelineDef,
params: Any,
overrides: ExecutionOverrides | None = None,
) -> None:
if getattr(pipeline, "requires_sync_runner", False):
raise RuntimeError(
"This pipeline contains synchronous streams (Iterator)."
" It must be executed with run() or migrated to AsyncIterator."
)
await AsyncPipelineExecutor(pipeline.dag).execute(params)
await AsyncPipelineExecutor(pipeline.dag, overrides=overrides).execute(params)
103 changes: 103 additions & 0 deletions synaflow/execution/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from collections.abc import Iterator, MutableMapping
from dataclasses import dataclass
from typing import Any

from synaflow.core.definition import PipelineDef


class PipelineRegistry(MutableMapping[str, Any]):
def __init__(
self,
*,
contract_keys: set[str],
fallback_values: dict[str, Any] | None = None,
) -> None:
self._contract_keys = set(contract_keys)
self._fallback_values = dict(fallback_values or {})
self._overrides: dict[str, Any] = {}

def __getitem__(self, key: str) -> Any:
self._validate_key(key)
if key in self._overrides:
return self._overrides[key]
if key in self._fallback_values:
return self._fallback_values[key]
raise KeyError(key)

def __setitem__(self, key: str, value: Any) -> None:
self._validate_key(key)
self._validate_value(key, value)
self._overrides[key] = value

def __delitem__(self, key: str) -> None:
self._validate_key(key)
if key not in self._overrides:
raise KeyError(key)
del self._overrides[key]

def __iter__(self) -> Iterator[str]:
return iter(sorted(self._contract_keys))

def __len__(self) -> int:
return len(self._contract_keys)

def resolve(self, key: str, default: Any = None) -> Any:
if key in self._overrides:
return self._overrides[key]
if key in self._fallback_values:
return self._fallback_values[key]
return default

def _validate_key(self, key: str) -> None:
if key not in self._contract_keys:
valid = ", ".join(sorted(self._contract_keys)) or "<none>"
raise KeyError(f"Unknown override key '{key}'. Valid keys: {valid}.")

def _validate_value(self, key: str, value: Any) -> None:
return None


class MaterializerRegistry(PipelineRegistry):
    """Registry of per-step materializer overrides for one pipeline."""

    @classmethod
    def empty(cls, pipeline: PipelineDef) -> "MaterializerRegistry":
        """Build a registry for *pipeline* with no overrides set.

        The contract keys and the fallback values are both derived from the
        pipeline's steps that have a non-``None`` materializer.
        """
        allowed_steps = _materializer_contract_keys(pipeline)
        compiled = _materializer_fallback_values(pipeline)
        return cls(contract_keys=allowed_steps, fallback_values=compiled)

    @classmethod
    def from_production(cls, pipeline: PipelineDef) -> "MaterializerRegistry":
        """Build a registry for *pipeline*; currently the same as ``empty()``."""
        return cls.empty(pipeline)

    def _validate_value(self, key: str, value: Any) -> None:
        """Reject any override that is not callable.

        Raises:
            TypeError: if *value* is not callable.
        """
        if callable(value):
            return
        raise TypeError(f"Materializer override for step '{key}' must be callable.")


@dataclass(frozen=True)
class ExecutionOverrides:
    """Immutable bundle of the override registries handed to a runner.

    The dataclass is frozen, so the ``materializers`` attribute cannot be
    rebound after construction.
    """

    # Registry holding the materializer replacements.
    materializers: MaterializerRegistry

    @classmethod
    def empty(cls, pipeline: PipelineDef) -> "ExecutionOverrides":
        """Create overrides for *pipeline* via ``MaterializerRegistry.empty``."""
        registry = MaterializerRegistry.empty(pipeline)
        return cls(materializers=registry)

    @classmethod
    def from_production(cls, pipeline: PipelineDef) -> "ExecutionOverrides":
        """Create overrides via ``MaterializerRegistry.from_production``."""
        registry = MaterializerRegistry.from_production(pipeline)
        return cls(materializers=registry)


def _materializer_contract_keys(pipeline: PipelineDef) -> set[str]:
    """Return the names of the steps whose materializer is not ``None``."""
    step_names: set[str] = set()
    for step_name, node in pipeline.dag.steps.items():
        if node.materializer is None:
            continue
        step_names.add(step_name)
    return step_names


def _materializer_fallback_values(pipeline: PipelineDef) -> dict[str, Any]:
    """Map each step name to its materializer, skipping steps that have none."""
    compiled: dict[str, Any] = {}
    for step_name, node in pipeline.dag.steps.items():
        if node.materializer is None:
            continue
        compiled[step_name] = node.materializer
    return compiled
Loading
Loading