Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,24 @@ compiles, it's valid. No runtime surprises.
### Runtime overrides on top of the compiled contract

When you need test-time swaps without patching module globals, pass
`ExecutionOverrides` to `run()` or `async_run()` and replace only the compiled
materializers you care about. The DAG shape and semantics stay fixed; only the
runtime callable changes.
`ExecutionOverrides` to `run()` or `async_run()` and replace only compiled
runtime dependencies such as materializers or observers. Use
`PIPELINE_SCOPE` for pipeline-level observers. The DAG shape and semantics stay
fixed; only the runtime callable changes.

```python
from synaflow import ExecutionOverrides, Observer, PIPELINE_SCOPE, Scope

overrides = ExecutionOverrides.empty(p)
sub = Scope("payments")

overrides.observers[PIPELINE_SCOPE] = [Observer(noop_metrics)]
overrides.observers[sub.scope("validate")] = [Observer(test_recorder)]
overrides.materializers[sub.scope("normalize")] = list
```

For included sub-pipelines, `Scope(...)` is the public helper for addressing
compiled step keys, so you never have to hardcode a joined key such as
`"payments__validate"`.

### Build your own runner

Expand Down
23 changes: 21 additions & 2 deletions docs/user_docs/core-concepts/build-vs-run.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,27 @@ resolved at build time and frozen in the JSON or `Dag`. Runners don't re-infer
semantics; they execute the contract.

`ExecutionOverrides` fits inside that boundary: it can swap the concrete
runtime callable for a compiled key such as a materializer, but it does not
change graph structure, dependency resolution, or eager-vs-lazy planning.
runtime callable for a compiled key such as a materializer or observer scope,
but it does not change graph structure, dependency resolution, or eager-vs-lazy
planning.

For nested pipelines, the public key helper is `Scope`, not manual string
concatenation:

```python
from synaflow import ExecutionOverrides, Observer, PIPELINE_SCOPE, Scope

overrides = ExecutionOverrides.empty(p)
sub = Scope("incl")

overrides.observers[PIPELINE_SCOPE] = [Observer(noop)]
overrides.observers[sub.scope("validate")] = [Observer(spy)]
overrides.materializers[sub.scope("prepare")] = tuple
```

The executor never understands sub-pipelines directly. `Scope` resolves to the
compiled DAG step key before execution starts, so runtime still operates on the
same flat compiled contract.

### 2. Write your own runner

Expand Down
4 changes: 4 additions & 0 deletions synaflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .core.definition import include, pipeline, step
from .core.constants import PIPELINE_SCOPE
from .core.naming import Scope
from .core.observers import (
MaterializationEvent,
Observer,
Expand All @@ -21,6 +23,8 @@
"pipeline",
"step",
"include",
"PIPELINE_SCOPE",
"Scope",
"run",
"async_run",
"ExecutionOverrides",
Expand Down
1 change: 1 addition & 0 deletions synaflow/core/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PIPELINE_SCOPE = "__pipeline__"
28 changes: 28 additions & 0 deletions synaflow/core/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,41 @@
plural, suffixed) without manual wiring.
"""

from dataclasses import dataclass

import inflect

_engine = inflect.engine()

_SUFFIXES = {"_list", "_set", "_dict", "_tuple"}


@dataclass(frozen=True)
class Scope:
    """Immutable, ordered sequence of name parts addressing a nested key.

    ``str(scope)`` joins the parts with ``"__"``; ``scope.scope(part)``
    returns a deeper ``Scope`` and ``scope(step_name)`` returns the joined
    string for ``step_name`` under this scope.
    """

    parts: tuple[str, ...]

    def __init__(self, *parts: str):
        """Build a scope from one or more non-empty string parts.

        Raises:
            ValueError: if no parts are given, or if any part fails
                ``_validate_scope_part``.
        """
        checked = tuple(map(_validate_scope_part, parts))
        if len(checked) == 0:
            raise ValueError("Scope requires at least one non-empty part.")
        # The dataclass is frozen, so plain attribute assignment would raise;
        # go through object.__setattr__ to initialise the field.
        object.__setattr__(self, "parts", checked)

    def scope(self, part: str) -> "Scope":
        """Return a new ``Scope`` with *part* appended to this one's parts."""
        extended = (*self.parts, part)
        return Scope(*extended)

    def __call__(self, step_name: str) -> str:
        """Return the joined string key for *step_name* under this scope."""
        child = self.scope(step_name)
        return str(child)

    def __str__(self) -> str:
        """Return all parts joined with a double underscore."""
        separator = "__"
        return separator.join(self.parts)


def _validate_scope_part(part: str) -> str:
    """Return *part* unchanged once it is known to be a non-empty string.

    Raises:
        ValueError: if *part* is not a ``str`` or is the empty string.
    """
    if isinstance(part, str) and part:
        return part
    raise ValueError("Scope parts must be non-empty strings.")


def get_base_dataset_name(name: str) -> str:
"""Return the absolute plural Base Dataset name.

Expand Down
5 changes: 4 additions & 1 deletion synaflow/core/type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
def is_factory(func: Callable) -> bool:
if not callable(func):
return False
sig = inspect.signature(func)
try:
sig = inspect.signature(func)
except (TypeError, ValueError):
return False
for param in sig.parameters.values():
if param.name in ("ctx", "context") or "MaterializeContext" in str(
param.annotation
Expand Down
8 changes: 7 additions & 1 deletion synaflow/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from .overrides import ExecutionOverrides, MaterializerRegistry, PipelineRegistry
from .overrides import (
ExecutionOverrides,
MaterializerRegistry,
ObserverRegistry,
PipelineRegistry,
)

__all__ = [
"ExecutionOverrides",
"MaterializerRegistry",
"ObserverRegistry",
"PipelineRegistry",
]
23 changes: 20 additions & 3 deletions synaflow/execution/async_engine/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterator
from typing import Any

from synaflow.core.constants import PIPELINE_SCOPE
from synaflow.core.dag import Dag
from synaflow.core.definition import PipelineDef
from synaflow.core.exceptions import PipelineStopException, StepExecutionError
Expand Down Expand Up @@ -227,13 +228,29 @@ def _resolve_materializer(self, step_name: str, node: Any) -> Any:
return node.materializer
return self._overrides.materializers.resolve(step_name, node.materializer)

def _resolve_pipeline_observers(self) -> list:
    """Return the pipeline-level observers that apply to this run.

    Without overrides the DAG's own ``pipeline_observers`` are returned.
    With overrides, the ``PIPELINE_SCOPE`` entry of the observer registry is
    resolved, passing the DAG's observers as the default.
    """
    overrides = self._overrides
    compiled = self.dag.pipeline_observers
    if overrides is not None:
        return overrides.observers.resolve(PIPELINE_SCOPE, compiled)
    return compiled

def _resolve_step_observers(self, node: Any, step_name: str) -> list:
    """Return the observers to notify for events of one step.

    The result is the resolved pipeline-level observers followed by the
    step's own observers (those on ``node.observers`` whose ``source`` is
    ``"step"``). When overrides are present, the step-level part is resolved
    through the observer registry under ``step_name``, with the compiled
    step observers as the default.
    """
    pipeline_level = self._resolve_pipeline_observers()

    step_level = []
    for registration in node.observers:
        if registration.source == "step":
            step_level.append(registration)

    overrides = self._overrides
    if overrides is not None:
        step_level = overrides.observers.resolve(step_name, step_level)

    combined = list(pipeline_level)
    combined.extend(step_level)
    return combined

async def _dispatch_pipeline_event(
self,
event: PipelineEvent,
step_name: str | None = None,
exception: BaseException | None = None,
) -> None:
registrations = self.dag.pipeline_observers
registrations = self._resolve_pipeline_observers()
if not registrations:
return
ctx: Any
Expand Down Expand Up @@ -269,7 +286,7 @@ async def _dispatch_step_event(
completed_all_inputs: bool = True,
exception: BaseException | None = None,
) -> None:
registrations = node.observers
registrations = self._resolve_step_observers(node, step_name)
if not registrations:
return
ctx: Any
Expand Down Expand Up @@ -317,7 +334,7 @@ async def _dispatch_materialization_event(
materializer_name: str | None = None,
exception: BaseException | None = None,
) -> None:
registrations = node.observers
registrations = self._resolve_step_observers(node, step_name)
if not registrations:
return
ctx: Any
Expand Down
135 changes: 110 additions & 25 deletions synaflow/execution/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from dataclasses import dataclass
from typing import Any

from synaflow.core.constants import PIPELINE_SCOPE
from synaflow.core.definition import PipelineDef
from synaflow.core.naming import Scope
from synaflow.core.observers import Observer, ResolvedObserver


class PipelineRegistry(MutableMapping[str, Any]):
Expand All @@ -16,38 +19,46 @@ def __init__(
self._fallback_values = dict(fallback_values or {})
self._overrides: dict[str, Any] = {}

def __getitem__(self, key: str) -> Any:
self._validate_key(key)
if key in self._overrides:
return self._overrides[key]
if key in self._fallback_values:
return self._fallback_values[key]
raise KeyError(key)

def __setitem__(self, key: str, value: Any) -> None:
self._validate_key(key)
self._validate_value(key, value)
self._overrides[key] = value

def __delitem__(self, key: str) -> None:
self._validate_key(key)
if key not in self._overrides:
raise KeyError(key)
del self._overrides[key]
def __getitem__(self, key: str | Scope) -> Any:
    """Return the effective value for *key*.

    The key is normalized and validated first. An explicit override wins
    over a fallback value; if neither layer holds the key, ``KeyError`` is
    raised with the normalized key.
    """
    name = self._normalize_key(key)
    self._validate_key(name)
    for layer in (self._overrides, self._fallback_values):
        if name in layer:
            return layer[name]
    raise KeyError(name)

def __setitem__(self, key: str | Scope, value: Any) -> None:
    """Store *value* as the override for *key*.

    The key is normalized and validated, and the value is passed through
    ``_normalize_value`` before it is recorded.
    """
    name = self._normalize_key(key)
    self._validate_key(name)
    stored = self._normalize_value(name, value)
    self._overrides[name] = stored

def __delitem__(self, key: str | Scope) -> None:
    """Remove the override stored for *key*; fallback values are untouched.

    Raises:
        KeyError: (with the normalized key) when no override is stored.
    """
    name = self._normalize_key(key)
    self._validate_key(name)
    if name in self._overrides:
        del self._overrides[name]
        return
    raise KeyError(name)

def __iter__(self) -> Iterator[str]:
    """Iterate over the contract keys in sorted order."""
    ordered_keys = sorted(self._contract_keys)
    return iter(ordered_keys)

def __len__(self) -> int:
    """Return the number of contract keys, whether overridden or not."""
    contract = self._contract_keys
    return len(contract)

def resolve(self, key: str, default: Any = None) -> Any:
if key in self._overrides:
return self._overrides[key]
if key in self._fallback_values:
return self._fallback_values[key]
def resolve(self, key: str | Scope, default: Any = None) -> Any:
    """Look up *key* without contract validation.

    Returns the override if one is stored, otherwise the fallback value,
    otherwise *default*. Unlike ``__getitem__`` this never raises for an
    unknown key.
    """
    name = self._normalize_key(key)
    for layer in (self._overrides, self._fallback_values):
        if name in layer:
            return layer[name]
    return default

def _normalize_key(self, key: str | Scope) -> str:
    """Convert a ``Scope`` to its string form; return any other key as-is."""
    return str(key) if isinstance(key, Scope) else key

def _validate_key(self, key: str) -> None:
if key not in self._contract_keys:
valid = ", ".join(sorted(self._contract_keys)) or "<none>"
Expand All @@ -56,6 +67,10 @@ def _validate_key(self, key: str) -> None:
def _validate_value(self, key: str, value: Any) -> None:
    """Validation hook for override values; this implementation accepts anything."""
    return None

def _normalize_value(self, key: str, value: Any) -> Any:
    """Validate *value* for *key* and return the object to store.

    This implementation runs ``_validate_value`` and then hands the value
    back unchanged.
    """
    # Validation runs first so an invalid value raises before it is returned.
    self._validate_value(key, value)
    stored = value
    return stored


class MaterializerRegistry(PipelineRegistry):
@classmethod
Expand All @@ -74,17 +89,62 @@ def _validate_value(self, key: str, value: Any) -> None:
raise TypeError(f"Materializer override for step '{key}' must be callable.")


class ObserverRegistry(PipelineRegistry):
    """Registry of observer overrides keyed by ``PIPELINE_SCOPE`` or step key.

    Stored values are always lists of ``ResolvedObserver``; plain callables
    and ``Observer`` registrations are converted on assignment.
    """

    @classmethod
    def empty(cls, pipeline: PipelineDef) -> "ObserverRegistry":
        """Build a registry in which every contract key falls back to ``[]``."""
        keys = _observer_contract_keys(pipeline)
        # A fresh list per key: the fallback lists must not be shared.
        no_observers = {}
        for key in keys:
            no_observers[key] = []
        return cls(contract_keys=keys, fallback_values=no_observers)

    @classmethod
    def from_production(cls, pipeline: PipelineDef) -> "ObserverRegistry":
        """Build a registry whose fallbacks come from the pipeline's own observers."""
        keys = _observer_contract_keys(pipeline)
        fallbacks = _observer_fallback_values(pipeline)
        return cls(contract_keys=keys, fallback_values=fallbacks)

    def _normalize_value(self, key: str, value: Any) -> list[ResolvedObserver]:
        """Convert an override list into a list of ``ResolvedObserver``.

        Items that are already ``ResolvedObserver`` are kept as they are.
        ``Observer`` registrations and bare callables are wrapped, tagged with
        source ``"pipeline"`` when *key* is ``PIPELINE_SCOPE`` and ``"step"``
        otherwise.

        Raises:
            TypeError: if *value* is not a ``list`` or contains an item that
                is neither callable nor an observer registration.
        """
        if not isinstance(value, list):
            raise TypeError(
                f"Observer override for scope '{key}' must be a list of observers."
            )

        source = "step" if key != PIPELINE_SCOPE else "pipeline"

        def _coerce(item: Any) -> ResolvedObserver:
            # Order matters: check the registration types before falling back
            # to the generic callable test.
            if isinstance(item, ResolvedObserver):
                return item
            if isinstance(item, Observer):
                return ResolvedObserver(handler=item.handler, source=source)
            if callable(item):
                return ResolvedObserver(handler=item, source=source)
            raise TypeError(
                f"Observer override for scope '{key}' must contain only "
                "callables or Observer registrations."
            )

        return [_coerce(item) for item in value]


@dataclass(frozen=True)
class ExecutionOverrides:
materializers: MaterializerRegistry
observers: ObserverRegistry

@classmethod
def empty(cls, pipeline: PipelineDef) -> "ExecutionOverrides":
return cls(materializers=MaterializerRegistry.empty(pipeline))
return cls(
materializers=MaterializerRegistry.empty(pipeline),
observers=ObserverRegistry.empty(pipeline),
)

@classmethod
def from_production(cls, pipeline: PipelineDef) -> "ExecutionOverrides":
return cls(materializers=MaterializerRegistry.from_production(pipeline))
return cls(
materializers=MaterializerRegistry.from_production(pipeline),
observers=ObserverRegistry.from_production(pipeline),
)


def _materializer_contract_keys(pipeline: PipelineDef) -> set[str]:
Expand All @@ -101,3 +161,28 @@ def _materializer_fallback_values(pipeline: PipelineDef) -> dict[str, Any]:
for step_name, node in pipeline.dag.steps.items()
if node.materializer is not None
}


def _observer_contract_keys(pipeline: PipelineDef) -> set[str]:
    """Return the observer scopes of *pipeline* that can be overridden.

    The set holds the name of every DAG step with a non-empty ``observers``
    list, plus ``PIPELINE_SCOPE`` when the DAG has pipeline observers.
    """
    dag = pipeline.dag
    keys = {name for name, node in dag.steps.items() if node.observers}
    if dag.pipeline_observers:
        keys.add(PIPELINE_SCOPE)
    return keys


def _observer_fallback_values(
    pipeline: PipelineDef,
) -> dict[str, list[ResolvedObserver]]:
    """Return the production observer lists used as registry fallbacks.

    ``PIPELINE_SCOPE`` maps to a copy of the DAG's pipeline observers when
    there are any. Every step with a non-empty ``observers`` list maps to the
    subset of those observers whose ``source`` is ``"step"``.

    A step whose observers are all non-step-sourced still gets an entry (an
    empty list). Such a step is a contract key, so without an entry the
    registry would yield the key during iteration but raise ``KeyError`` on
    lookup. Steps with no observers at all are left out.
    """
    dag = pipeline.dag
    values: dict[str, list[ResolvedObserver]] = {}
    if dag.pipeline_observers:
        values[PIPELINE_SCOPE] = list(dag.pipeline_observers)
    for step_name, node in dag.steps.items():
        if not node.observers:
            # No observers at all: this step has no fallback entry.
            continue
        values[step_name] = [
            observer for observer in node.observers if observer.source == "step"
        ]
    return values
Loading
Loading