From 0778bd1ee5e7c87b4274e098262e5104f5d9614b Mon Sep 17 00:00:00 2001 From: Marcelo Valle Date: Sat, 20 Jun 2026 02:12:05 +0100 Subject: [PATCH 1/3] add runtime resource overrides --- synaflow/__init__.py | 3 +- synaflow/core/dag.py | 7 + synaflow/core/dag_builder.py | 39 ++++- synaflow/core/dag_dependencies.py | 13 +- synaflow/core/dag_steps.py | 3 +- synaflow/core/definition.py | 2 + synaflow/execution/__init__.py | 2 + synaflow/execution/async_engine/executor.py | 21 ++- synaflow/execution/overrides.py | 21 +++ synaflow/execution/sync_engine/executor.py | 23 ++- tests/core/test_dag_builder.py | 60 +++++++ tests/execution/test_execution_overrides.py | 180 ++++++++++++++++++++ 12 files changed, 366 insertions(+), 8 deletions(-) diff --git a/synaflow/__init__.py b/synaflow/__init__.py index ea4af72..5c36883 100644 --- a/synaflow/__init__.py +++ b/synaflow/__init__.py @@ -8,7 +8,7 @@ StepEvent, ) from .core.types import OnError, StepMode, StepParams, StepResult -from .execution import ExecutionOverrides +from .execution import ExecutionOverrides, ResourceRegistry from .execution.async_engine.executor import async_run from .execution.sync_engine.executor import run from .serializers import ( @@ -28,6 +28,7 @@ "run", "async_run", "ExecutionOverrides", + "ResourceRegistry", "OnError", "StepMode", "StepParams", diff --git a/synaflow/core/dag.py b/synaflow/core/dag.py index 2257920..8c02093 100644 --- a/synaflow/core/dag.py +++ b/synaflow/core/dag.py @@ -102,6 +102,7 @@ def _serialize_pipeline_observers(observers: list) -> list[dict]: class Dag: name: str = "" params: dict[str, Any] = field(default_factory=dict) + resources: dict[str, Any] = field(default_factory=dict) steps: dict[str, DagNode] = field(default_factory=dict) requires_sync_runner: bool = False requires_async_runner: bool = False @@ -132,6 +133,8 @@ def values(self): def get(self, key, default=None): if key in self.steps: return self.steps[key] + if key in self.resources: + return 
DagNode(output=self.resources[key]) if key in self.params: return DagNode(output=self.params[key]) return default @@ -147,6 +150,10 @@ def to_dict(self) -> dict: name: node.to_serializable() for name, node in self.steps.items() }, } + if self.resources: + result["resources"] = { + k: get_type_name(v) for k, v in self.resources.items() + } if self.error_materializer_factory is not None: result["error_materializer"] = self.error_materializer_factory.__name__ if self.pipeline_observers: diff --git a/synaflow/core/dag_builder.py b/synaflow/core/dag_builder.py index 86eb0a2..5997df6 100644 --- a/synaflow/core/dag_builder.py +++ b/synaflow/core/dag_builder.py @@ -39,7 +39,7 @@ ) from synaflow.core.dag import Dag, DagNode -from synaflow.core.dag_dependencies import initialize_parameters +from synaflow.core.dag_dependencies import initialize_parameters, initialize_resources from synaflow.core.dag_expansion import expand_macros from synaflow.core.dag_steps import ( validate_and_compile_step, @@ -169,6 +169,26 @@ def _validate_declared_step_names(steps: list[Any], pipeline_name: str) -> None: validate_unique_step_name(step.name, {}, pipeline_name) +def _validate_resource_names( + resources: dict[str, Any], + params: type[NamedTuple], + expanded_steps: list[Any], + pipeline_name: str, +) -> None: + param_fields = set(getattr(params, "_fields", [])) + step_names = {step.name for step in expanded_steps} + + for resource_name in resources: + if resource_name in param_fields: + raise ValueError( + f"Pipeline '{pipeline_name}': resource '{resource_name}' collides with a params field." + ) + if resource_name in step_names: + raise ValueError( + f"Pipeline '{pipeline_name}': resource '{resource_name}' collides with a step name." 
+ ) + + def _resolve_pipeline_observers( pipeline_observers: list[Observer], ) -> list[ResolvedObserver]: @@ -328,10 +348,13 @@ def _compile_steps( expanded_steps: list[Any], pipeline_name: str, params: type[NamedTuple], + resources: dict[str, Any], pipeline_observers: list[ResolvedObserver], ) -> tuple[dict[str, DagNode], dict[str, DagNode]]: dag: dict[str, DagNode] = {} produced = initialize_parameters(params) + produced.update(initialize_resources(resources)) + resource_nodes = initialize_resources(resources) for step in expanded_steps: validate_step_is_callable(step, pipeline_name) @@ -340,6 +363,7 @@ def _compile_steps( compiled_step = validate_and_compile_step( step, produced, + resource_nodes, pipeline_name, observers=_resolve_step_observers(pipeline_observers, step.observers), ) @@ -353,12 +377,18 @@ def _finalize_dag( pipeline_name: str, dag: dict[str, DagNode], produced: dict[str, DagNode], + resource_names: set[str], error_materializer_factory: Any, pipeline_observers: list[ResolvedObserver], ) -> Dag: dag_obj = Dag(name=pipeline_name) dag_obj.params = { - name: info.output for name, info in produced.items() if name not in dag + name: info.output + for name, info in produced.items() + if name not in dag and name not in resource_names + } + dag_obj.resources = { + name: info.output for name, info in produced.items() if name in resource_names } dag_obj.steps = dag dag_obj.error_materializer_factory = error_materializer_factory @@ -370,6 +400,7 @@ def build_dag( pipeline_name: str, params: type[NamedTuple], steps: list[Any], + resources: dict[str, Any] | None = None, memory_materializer_factory: Any = None, is_default_factory: bool = False, error_materializer_factory: Any = None, @@ -382,10 +413,13 @@ def build_dag( _validate_params_is_namedtuple(params, pipeline_name) pipeline_obs_resolved = _resolve_pipeline_observers(pipeline_observers or []) expanded_steps = _expand_and_validate_steps(steps, pipeline_name) + resources = resources or {} + 
_validate_resource_names(resources, params, expanded_steps, pipeline_name) dag, produced = _compile_steps( expanded_steps, pipeline_name, params, + resources, pipeline_obs_resolved, ) _compute_materialized_deps(dag) @@ -393,6 +427,7 @@ def build_dag( pipeline_name, dag, produced, + set(resources), error_materializer_factory, pipeline_obs_resolved, ) diff --git a/synaflow/core/dag_dependencies.py b/synaflow/core/dag_dependencies.py index 406971e..a01f137 100644 --- a/synaflow/core/dag_dependencies.py +++ b/synaflow/core/dag_dependencies.py @@ -31,6 +31,14 @@ def initialize_parameters(params: type[NamedTuple]) -> dict[str, DagNode]: return produced +def initialize_resources(resources: dict[str, Any]) -> dict[str, DagNode]: + produced: dict[str, DagNode] = {} + for name, resource in resources.items(): + resource_type = resource if isinstance(resource, type) else type(resource) + produced[name] = DagNode(output=resource_type) + return produced + + def get_safe_type_hints(fn: Any) -> dict[str, Any]: try: return typing.get_type_hints(fn, include_extras=True) @@ -52,6 +60,7 @@ def validate_and_resolve_dependencies( sig: inspect.Signature, hints: dict[str, Any], produced: dict[str, DagNode], + resources: dict[str, DagNode], pipeline_name: str, ) -> tuple[dict[str, Any], dict[str, str]]: deps: dict[str, Any] = {} @@ -62,7 +71,9 @@ def validate_and_resolve_dependencies( if consumer_type is inspect.Parameter.empty: consumer_type = None - if param_name in produced: + if param_name in resources: + producer_name = param_name + elif param_name in produced: producer_name = param_name else: param_base = get_base_dataset_name(param_name) diff --git a/synaflow/core/dag_steps.py b/synaflow/core/dag_steps.py index fe27a74..980806c 100644 --- a/synaflow/core/dag_steps.py +++ b/synaflow/core/dag_steps.py @@ -42,6 +42,7 @@ def validate_unique_step_name( def validate_and_compile_step( step: Step, produced: dict[str, DagNode], + resources: dict[str, DagNode], pipeline_name: str, observers: 
list | None = None, ) -> DagNode: @@ -51,7 +52,7 @@ def validate_and_compile_step( _validate_max_in_flight(step, pipeline_name) deps, dataset_param_names = validate_and_resolve_dependencies( - step, sig, hints, produced, pipeline_name + step, sig, hints, produced, resources, pipeline_name ) mode, each_mode_deps = resolve_step_mode(step, deps, produced, pipeline_name) diff --git a/synaflow/core/definition.py b/synaflow/core/definition.py index 63606a7..959d05a 100644 --- a/synaflow/core/definition.py +++ b/synaflow/core/definition.py @@ -43,6 +43,7 @@ class PipelineDef: name: str params: Any steps: list[Step | IncludeStep] + resources: dict[str, Any] = field(default_factory=dict) exports: str | None = None materializer: Callable | None = None error_materializer: Callable | None = None @@ -58,6 +59,7 @@ def __post_init__(self) -> None: self.name, self.params, self.steps, + self.resources, self.materializer, is_default_factory=(self.materializer is None), error_materializer_factory=self.error_materializer, diff --git a/synaflow/execution/__init__.py b/synaflow/execution/__init__.py index 0d70d2c..527ab00 100644 --- a/synaflow/execution/__init__.py +++ b/synaflow/execution/__init__.py @@ -3,6 +3,7 @@ MaterializerRegistry, ObserverRegistry, PipelineRegistry, + ResourceRegistry, ) __all__ = [ @@ -10,4 +11,5 @@ "MaterializerRegistry", "ObserverRegistry", "PipelineRegistry", + "ResourceRegistry", ] diff --git a/synaflow/execution/async_engine/executor.py b/synaflow/execution/async_engine/executor.py index a872af8..ce3cac4 100644 --- a/synaflow/execution/async_engine/executor.py +++ b/synaflow/execution/async_engine/executor.py @@ -374,10 +374,29 @@ async def _dispatch_materialization_event( # Execution # ------------------------------------------------------------------ - async def execute(self, params: Any) -> None: + def _seed_runtime_inputs(self, params: Any) -> None: + if self.dag.resources: + if self._overrides is None: + resource_names = ", 
".join(sorted(self.dag.resources)) + raise ValueError( + f"Pipeline '{self.dag.name}' requires runtime resources: {resource_names}." + ) + for resource_name in self.dag.resources: + try: + self.outputs[resource_name] = self._overrides.resources[ + resource_name + ] + except KeyError as exc: + raise ValueError( + f"Pipeline '{self.dag.name}' requires resource '{resource_name}' at runtime." + ) from exc + for field, value in params._asdict().items(): self.outputs[field] = value + async def execute(self, params: Any) -> None: + self._seed_runtime_inputs(params) + await self._dispatch_pipeline_event(PipelineEvent.STARTED) try: for level in self.dag.get_execution_levels(): diff --git a/synaflow/execution/overrides.py b/synaflow/execution/overrides.py index 2d0f7c3..f82e53e 100644 --- a/synaflow/execution/overrides.py +++ b/synaflow/execution/overrides.py @@ -127,16 +127,32 @@ def _normalize_value(self, key: str, value: Any) -> list[ResolvedObserver]: return normalized +class ResourceRegistry(PipelineRegistry): + @classmethod + def empty(cls, pipeline: PipelineDef) -> "ResourceRegistry": + return cls(contract_keys=_resource_contract_keys(pipeline)) + + @classmethod + def from_production(cls, pipeline: PipelineDef) -> "ResourceRegistry": + return cls.empty(pipeline) + + def _validate_value(self, key: str, value: Any) -> None: + if value is None: + raise TypeError(f"Resource override for key '{key}' cannot be None.") + + @dataclass(frozen=True) class ExecutionOverrides: materializers: MaterializerRegistry observers: ObserverRegistry + resources: ResourceRegistry @classmethod def empty(cls, pipeline: PipelineDef) -> "ExecutionOverrides": return cls( materializers=MaterializerRegistry.empty(pipeline), observers=ObserverRegistry.empty(pipeline), + resources=ResourceRegistry.empty(pipeline), ) @classmethod @@ -144,6 +160,7 @@ def from_production(cls, pipeline: PipelineDef) -> "ExecutionOverrides": return cls( materializers=MaterializerRegistry.from_production(pipeline), 
observers=ObserverRegistry.from_production(pipeline), + resources=ResourceRegistry.from_production(pipeline), ) @@ -186,3 +203,7 @@ def _observer_fallback_values( if step_local: values[step_name] = step_local return values + + +def _resource_contract_keys(pipeline: PipelineDef) -> set[str]: + return set(pipeline.dag.resources) diff --git a/synaflow/execution/sync_engine/executor.py b/synaflow/execution/sync_engine/executor.py index b1c9791..8ec8fdf 100644 --- a/synaflow/execution/sync_engine/executor.py +++ b/synaflow/execution/sync_engine/executor.py @@ -264,14 +264,33 @@ def _dispatch_materialization_event( # Execution # ------------------------------------------------------------------ + def _seed_runtime_inputs(self, params: Any) -> None: + if self.dag.resources: + if self._overrides is None: + resource_names = ", ".join(sorted(self.dag.resources)) + raise ValueError( + f"Pipeline '{self.dag.name}' requires runtime resources: {resource_names}." + ) + for resource_name in self.dag.resources: + try: + self.outputs[resource_name] = self._overrides.resources[ + resource_name + ] + except KeyError as exc: + raise ValueError( + f"Pipeline '{self.dag.name}' requires resource '{resource_name}' at runtime." 
+ ) from exc + + for field, value in params._asdict().items(): + self.outputs[field] = value + def _resolve_materializer(self, step_name: str, node: Any) -> Any: if self._overrides is None: return node.materializer return self._overrides.materializers.resolve(step_name, node.materializer) def execute(self, params: Any) -> None: - for field, value in params._asdict().items(): - self.outputs[field] = value + self._seed_runtime_inputs(params) self._dispatch_pipeline_event(PipelineEvent.STARTED) try: diff --git a/tests/core/test_dag_builder.py b/tests/core/test_dag_builder.py index 3021ec7..eda6e44 100644 --- a/tests/core/test_dag_builder.py +++ b/tests/core/test_dag_builder.py @@ -77,6 +77,66 @@ def fn(limit: int) -> int: pipeline(name="t", params=P, steps=[step("s1", fn=fn)]) +def test_given_dependency_on_declared_resource_when_constructed_then_passes(): + class DB: + pass + + class P(NamedTuple): + limit: int = 10 + + def fn(db: DB, limit: int) -> int: + return limit + + p = pipeline( + name="t", + params=P, + resources={"db": DB}, + steps=[step("s1", fn=fn)], + ) + + assert p.dag.resources == {"db": DB} + assert p.dag.steps["s1"].deps == {"db": DB, "limit": int} + assert p.to_dict()["resources"] == {"db": "DB"} + + +def test_given_resource_name_colliding_with_params_field_when_constructed_then_raises(): + class DB: + pass + + class P(NamedTuple): + db: int = 10 + + def fn(db: DB) -> None: + pass + + with pytest.raises(ValueError, match="collides with a params field"): + pipeline( + name="t", + params=P, + resources={"db": DB}, + steps=[step("s1", fn=fn)], + ) + + +def test_given_resource_name_colliding_with_step_name_when_constructed_then_raises(): + class DB: + pass + + class P(NamedTuple): + limit: int = 10 + + def fn(limit: int) -> int: + return limit + + with pytest.raises(ValueError, match="collides with a step name"): + pipeline( + name="t", + params=P, + resources={"db": DB}, + steps=[step("db", fn=fn)], + ) + + def 
test_given_duplicate_step_name_when_constructed_then_raises(): class P(NamedTuple): pass diff --git a/tests/execution/test_execution_overrides.py b/tests/execution/test_execution_overrides.py index 5d66805..2b4a107 100644 --- a/tests/execution/test_execution_overrides.py +++ b/tests/execution/test_execution_overrides.py @@ -7,6 +7,7 @@ Observer, PIPELINE_SCOPE, PipelineEvent, + ResourceRegistry, Scope, StepEvent, async_run, @@ -492,3 +493,182 @@ def consume(incl: int) -> None: ("incl__prepare", "StepStartedContext"), ("incl__prepare", "StepCompletedContext"), ] + + +def test_given_resource_registry_empty_when_missing_required_resource_then_run_raises( + run_pipeline, +): + class DB: + pass + + class Params(NamedTuple): + value: int = 1 + + seen = [] + + def use(db: DB, value: int) -> None: + seen.append((db, value)) + + p = pipeline( + name="missing_resource", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = ExecutionOverrides.empty(p) + + with pytest.raises(ValueError, match="requires resource 'db'"): + run_pipeline(p, Params(), overrides=overrides) + + assert seen == [] + + +def test_given_resource_registry_without_overrides_when_pipeline_requires_resources_then_run_raises( + run_pipeline, +): + class DB: + pass + + class Params(NamedTuple): + value: int = 1 + + def use(db: DB, value: int) -> None: + return None + + p = pipeline( + name="missing_resource_registry", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + with pytest.raises(ValueError, match="requires runtime resources: db"): + run_pipeline(p, Params()) + + +def test_given_resource_override_when_sync_run_then_resource_is_injected(run_pipeline): + class DB: + pass + + class Params(NamedTuple): + value: int = 3 + + seen = [] + + def use(db: DB, value: int) -> None: + seen.append((db, value)) + + p = pipeline( + name="sync_resource_override", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = 
ExecutionOverrides.empty(p) + db = DB() + overrides.resources["db"] = db + + run_pipeline(p, Params(), overrides=overrides) + + assert seen == [(db, 3)] + + +async def test_given_resource_override_when_async_run_then_resource_is_injected(): + class DB: + pass + + class Params(NamedTuple): + value: int = 3 + + seen = [] + + async def use(db: DB, value: int) -> None: + seen.append((db, value)) + + p = pipeline( + name="async_resource_override", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = ExecutionOverrides.empty(p) + db = DB() + overrides.resources["db"] = db + + await async_run(p, Params(), overrides=overrides) + + assert seen == [(db, 3)] + + +def test_given_execution_overrides_from_production_when_resources_requested_then_registry_is_empty_but_keyed(): + class DB: + pass + + class Params(NamedTuple): + value: int = 1 + + def use(db: DB, value: int) -> None: + return None + + p = pipeline( + name="resource_contract", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = ExecutionOverrides.from_production(p) + + assert isinstance(overrides.resources, ResourceRegistry) + assert list(overrides.resources) == ["db"] + with pytest.raises(KeyError): + _ = overrides.resources["db"] + + +def test_given_unknown_resource_override_key_when_assigned_then_raises(): + class DB: + pass + + class Params(NamedTuple): + value: int = 1 + + def use(db: DB, value: int) -> None: + return None + + p = pipeline( + name="invalid_resource_key", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = ExecutionOverrides.empty(p) + + with pytest.raises(KeyError, match="Unknown override key 'missing'"): + overrides.resources["missing"] = DB() + + +def test_given_none_resource_override_value_when_assigned_then_raises(): + class DB: + pass + + class Params(NamedTuple): + value: int = 1 + + def use(db: DB, value: int) -> None: + return None + + p = pipeline( + 
name="invalid_resource_value", + params=Params, + resources={"db": DB}, + steps=[step("use", fn=use)], + ) + + overrides = ExecutionOverrides.empty(p) + + with pytest.raises(TypeError, match="cannot be None"): + overrides.resources["db"] = None From cd790278b0773ede0c131e0b3b6566f346ab6177 Mon Sep 17 00:00:00 2001 From: Marcelo Valle Date: Sat, 20 Jun 2026 02:13:21 +0100 Subject: [PATCH 2/3] document runtime resource overrides --- README.md | 3 +++ docs/user_docs/core-concepts/build-vs-run.md | 13 ++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5374c45..6841f13 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ from synaflow import ExecutionOverrides, Observer, PIPELINE_SCOPE, Scope overrides = ExecutionOverrides.empty(p) sub = Scope("payments") +overrides.resources["db"] = FakeDatabase() overrides.observers[PIPELINE_SCOPE] = [Observer(noop_metrics)] overrides.observers[sub.scope("validate")] = [Observer(test_recorder)] overrides.materializers[sub.scope("normalize")] = list @@ -95,6 +96,8 @@ overrides.materializers[sub.scope("normalize")] = list For included sub-pipelines, `Scope(...)` is the public helper for addressing compiled step keys without hardcoding `"payments__validate"` by hand. +Declared `resources={...}` are runtime-only and must be provided through +`ExecutionOverrides.resources`. ### Build your own runner diff --git a/docs/user_docs/core-concepts/build-vs-run.md b/docs/user_docs/core-concepts/build-vs-run.md index 310bb84..827b418 100644 --- a/docs/user_docs/core-concepts/build-vs-run.md +++ b/docs/user_docs/core-concepts/build-vs-run.md @@ -68,9 +68,9 @@ resolved at build time and frozen in the JSON or `Dag`. Runners don't re-infer semantics; they execute the contract. 
`ExecutionOverrides` fits inside that boundary: it can swap the concrete -runtime callable for a compiled key such as a materializer or observer scope, -but it does not change graph structure, dependency resolution, or eager-vs-lazy -planning. +runtime value for a compiled key such as a materializer, observer scope, or +declared runtime resource, but it does not change graph structure, dependency +resolution, or eager-vs-lazy planning. For nested pipelines, the public key helper is `Scope`, not manual string concatenation: @@ -81,6 +81,7 @@ from synaflow import ExecutionOverrides, Observer, PIPELINE_SCOPE, Scope overrides = ExecutionOverrides.empty(p) sub = Scope("incl") +overrides.resources["db"] = FakeDatabase() overrides.observers[PIPELINE_SCOPE] = [Observer(noop)] overrides.observers[sub.scope("validate")] = [Observer(spy)] overrides.materializers[sub.scope("prepare")] = tuple @@ -90,6 +91,11 @@ The executor never understands sub-pipelines directly. `Scope` resolves to the compiled DAG step key before execution starts, so runtime still operates on the same flat compiled contract. +For resources, the key space is explicit in the compiled pipeline contract via +`pipeline(resources={...})`. Unlike materializers and observers, resources are +runtime-only: if a pipeline declares one, `run()` / `async_run()` must receive +it via `ExecutionOverrides.resources`, or execution fails loudly. + ### 2. Write your own runner The `Dag` object is self-contained. 
Anyone can write a runner: @@ -141,6 +147,7 @@ Every domain concern has a symmetric representation in both phases: | Mode resolution | Resolved at build time → `node.mode` | Executor reads `node.mode`, never re-infers | | Materialization | Resolved at build time → `node.materializer`, `materialized_deps`, `Dag.needs_materialize(...)` | Executor calls the resolved callable and follows the compiled plan | | Observers | Normalized at build time → `node.observers` | Executor dispatches events | +| Resources | Declared at build time → `dag.resources` | Executor requires concrete runtime values via `ExecutionOverrides.resources` | This symmetry means sync and async executors can be completely different implementations (one uses generators, the other uses `asyncio.Queue`) but From 9a06fe20798154ecdaa3836b1c25f0c9c78e18f5 Mon Sep 17 00:00:00 2001 From: Marcelo Valle Date: Sat, 20 Jun 2026 08:33:24 +0100 Subject: [PATCH 3/3] inherit sub-pipeline resources in contract --- synaflow/core/dag_builder.py | 57 +++++++++++- synaflow/core/dag_expansion.py | 15 ++++ tests/core/test_dag_builder.py | 98 +++++++++++++++++++++ tests/execution/test_execution_overrides.py | 50 +++++++++++ 4 files changed, 216 insertions(+), 4 deletions(-) diff --git a/synaflow/core/dag_builder.py b/synaflow/core/dag_builder.py index 5997df6..7cea48b 100644 --- a/synaflow/core/dag_builder.py +++ b/synaflow/core/dag_builder.py @@ -40,6 +40,7 @@ from synaflow.core.dag import Dag, DagNode from synaflow.core.dag_dependencies import initialize_parameters, initialize_resources +from synaflow.core.definition import IncludeStep from synaflow.core.dag_expansion import expand_macros from synaflow.core.dag_steps import ( validate_and_compile_step, @@ -189,6 +190,49 @@ def _validate_resource_names( ) +def _merge_resources( + merged: dict[str, Any], + incoming: dict[str, Any], + pipeline_name: str, +) -> None: + for resource_name, resource in incoming.items(): + if resource_name in merged and merged[resource_name] is 
not resource: + raise ValueError( + f"Pipeline '{pipeline_name}': resource '{resource_name}' is declared multiple times with different instances/factories." + ) + merged.setdefault(resource_name, resource) + + +def _collect_pipeline_resources( + pipeline_name: str, + steps: list[Any], + resources: dict[str, Any], + include_chain: tuple[str, ...] = (), +) -> dict[str, Any]: + merged: dict[str, Any] = {} + _merge_resources(merged, resources, pipeline_name) + + for step in steps: + if not isinstance(step, IncludeStep): + continue + + sub_pipeline = step.pipeline + if sub_pipeline.name in include_chain: + raise ValueError( + f"Infinite cycle detected: Pipeline '{sub_pipeline.name}' is already in the inclusion chain '{'.'.join(include_chain)}'" + ) + + sub_resources = _collect_pipeline_resources( + pipeline_name, + sub_pipeline.steps, + sub_pipeline.resources, + (*include_chain, sub_pipeline.name), + ) + _merge_resources(merged, sub_resources, pipeline_name) + + return merged + + def _resolve_pipeline_observers( pipeline_observers: list[Observer], ) -> list[ResolvedObserver]: @@ -413,13 +457,18 @@ def build_dag( _validate_params_is_namedtuple(params, pipeline_name) pipeline_obs_resolved = _resolve_pipeline_observers(pipeline_observers or []) expanded_steps = _expand_and_validate_steps(steps, pipeline_name) - resources = resources or {} - _validate_resource_names(resources, params, expanded_steps, pipeline_name) + effective_resources = _collect_pipeline_resources( + pipeline_name, + steps, + resources or {}, + include_chain=(pipeline_name,), + ) + _validate_resource_names(effective_resources, params, expanded_steps, pipeline_name) dag, produced = _compile_steps( expanded_steps, pipeline_name, params, - resources, + effective_resources, pipeline_obs_resolved, ) _compute_materialized_deps(dag) @@ -427,7 +476,7 @@ def build_dag( pipeline_name, dag, produced, - set(resources), + set(effective_resources), error_materializer_factory, pipeline_obs_resolved, ) diff --git 
a/synaflow/core/dag_expansion.py b/synaflow/core/dag_expansion.py index 5ed63d8..538119f 100644 --- a/synaflow/core/dag_expansion.py +++ b/synaflow/core/dag_expansion.py @@ -92,6 +92,10 @@ def _extract_sub_pipeline_param_fields(params: Any) -> list[str]: return [] +def _extract_sub_pipeline_resource_fields(resources: dict[str, Any]) -> list[str]: + return list(resources) + + def _build_expanded_step_name(prefix: str, sub_step: Step, exported_name: str) -> str: if sub_step.name == exported_name: return prefix @@ -114,6 +118,7 @@ def _expand_sub_pipeline_steps( include_step: IncludeStep, adapter_name: str, sub_pipeline_param_fields: list[str], + sub_pipeline_resource_fields: list[str], new_parent_chain: str | None, ) -> list[Step]: prefix = include_step.name @@ -131,6 +136,7 @@ def _expand_sub_pipeline_steps( prefix, adapter_name, sub_pipeline_param_fields, + sub_pipeline_resource_fields, sub_pipeline.params, ) materializer, error_materializer, observers = _resolve_sub_step_overrides( @@ -172,10 +178,14 @@ def _expand_include( include_step, current_pipeline_name, parent_chain ) sub_pipeline_param_fields = _extract_sub_pipeline_param_fields(sub_pipeline.params) + sub_pipeline_resource_fields = _extract_sub_pipeline_resource_fields( + sub_pipeline.resources + ) expanded_steps = _expand_sub_pipeline_steps( include_step, adapter_name, sub_pipeline_param_fields, + sub_pipeline_resource_fields, new_parent_chain, ) return [adapter_step, *expanded_steps] @@ -186,11 +196,14 @@ def _build_argument_mapping( prefix: str, adapter_name: str, sub_pipeline_param_fields: list[str], + sub_pipeline_resource_fields: list[str], ) -> dict[str, str]: arg_mapping: dict[str, str] = {} for param_name in signature.parameters: if param_name in sub_pipeline_param_fields: arg_mapping[param_name] = adapter_name + elif param_name in sub_pipeline_resource_fields: + arg_mapping[param_name] = param_name else: arg_mapping[param_name] = f"{prefix}__{param_name}" return arg_mapping @@ -241,6 +254,7 @@ 
def _wrap_sub_step_fn( prefix: str, adapter_name: str, sub_pipeline_param_fields: list[str], + sub_pipeline_resource_fields: list[str], sub_pipeline_params_class: Any, ) -> Any: signature = inspect.signature(original_fn) @@ -249,6 +263,7 @@ def _wrap_sub_step_fn( prefix, adapter_name, sub_pipeline_param_fields, + sub_pipeline_resource_fields, ) if inspect.iscoroutinefunction(original_fn): diff --git a/tests/core/test_dag_builder.py b/tests/core/test_dag_builder.py index eda6e44..2ab6a36 100644 --- a/tests/core/test_dag_builder.py +++ b/tests/core/test_dag_builder.py @@ -3,6 +3,7 @@ import pytest from synaflow import StepMode, pipeline, step +from synaflow.core.definition import include from synaflow.core.types import OnError @@ -137,6 +138,103 @@ def fn(limit: int) -> int: ) +def test_given_sub_pipeline_resource_when_constructed_then_resource_is_inherited_into_parent_contract(): + class DB: + pass + + class SubParams(NamedTuple): + value: int + + class Params(NamedTuple): + value: int = 10 + + def use(db: DB, value: int) -> int: + return value + + sub = pipeline( + name="sub", + params=SubParams, + resources={"db": DB}, + steps=[step("use", fn=use)], + exports="use", + ) + + def adapt(value: int) -> SubParams: + return SubParams(value=value) + + p = pipeline( + name="parent", + params=Params, + steps=[include("incl", pipeline=sub, fn=adapt)], + ) + + assert p.dag.resources == {"db": DB} + assert p.dag.steps["incl"].deps == {"db": DB, "incl__adapter": SubParams} + + +def test_given_parent_and_sub_pipeline_same_resource_instance_when_constructed_then_builds(): + shared = object() + + class SubParams(NamedTuple): + value: int + + class Params(NamedTuple): + value: int = 10 + + def use(db: object, value: int) -> int: + return value + + sub = pipeline( + name="sub", + params=SubParams, + resources={"db": shared}, + steps=[step("use", fn=use)], + exports="use", + ) + + def adapt(value: int) -> SubParams: + return SubParams(value=value) + + p = pipeline( + name="parent", 
+ params=Params, + resources={"db": shared}, + steps=[include("incl", pipeline=sub, fn=adapt)], + ) + + assert p.dag.resources["db"] is object + + +def test_given_parent_and_sub_pipeline_different_resource_instances_with_same_name_when_constructed_then_raises(): + class SubParams(NamedTuple): + value: int + + class Params(NamedTuple): + value: int = 10 + + def use(db: object, value: int) -> int: + return value + + sub = pipeline( + name="sub", + params=SubParams, + resources={"db": object()}, + steps=[step("use", fn=use)], + exports="use", + ) + + def adapt(value: int) -> SubParams: + return SubParams(value=value) + + with pytest.raises(ValueError, match="resource 'db' is declared multiple times"): + pipeline( + name="parent", + params=Params, + resources={"db": object()}, + steps=[include("incl", pipeline=sub, fn=adapt)], + ) + + def test_given_duplicate_step_name_when_constructed_then_raises(): class P(NamedTuple): pass diff --git a/tests/execution/test_execution_overrides.py b/tests/execution/test_execution_overrides.py index 2b4a107..1752901 100644 --- a/tests/execution/test_execution_overrides.py +++ b/tests/execution/test_execution_overrides.py @@ -672,3 +672,53 @@ def use(db: DB, value: int) -> None: with pytest.raises(TypeError, match="cannot be None"): overrides.resources["db"] = None + + +def test_given_sub_pipeline_resource_when_overridden_then_resource_is_injected_into_included_step( + run_pipeline, +): + class DB: + pass + + class SubParams(NamedTuple): + value: int + + class Params(NamedTuple): + value: int = 3 + + seen = [] + + def use(db: DB, value: int) -> int: + seen.append((db, value)) + return value + + sub = pipeline( + name="sub", + params=SubParams, + resources={"db": DB}, + steps=[step("use", fn=use)], + exports="use", + ) + + def adapt(value: int) -> SubParams: + return SubParams(value=value) + + def consume(incl: int) -> None: + return None + + p = pipeline( + name="parent", + params=Params, + steps=[ + include("incl", pipeline=sub, 
fn=adapt), + step("consume", fn=consume), + ], + ) + + overrides = ExecutionOverrides.empty(p) + db = DB() + overrides.resources["db"] = db + + run_pipeline(p, Params(), overrides=overrides) + + assert seen == [(db, 3)]