Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions plugboard-schemas/plugboard_schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from importlib.metadata import version

from ._common import PlugboardBaseModel
from ._graph import simple_cycles
from ._validation import (
validate_all_inputs_connected,
validate_input_events,
validate_no_unresolved_cycles,
validate_process,
)
from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource
from .config import ConfigSpec, ProcessConfigSpec
from .connector import (
Expand Down Expand Up @@ -85,4 +92,9 @@
"TuneArgsDict",
"TuneArgsSpec",
"TuneSpec",
"simple_cycles",
"validate_all_inputs_connected",
"validate_input_events",
"validate_no_unresolved_cycles",
"validate_process",
]
126 changes: 126 additions & 0 deletions plugboard-schemas/plugboard_schemas/_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Graph algorithms for topology validation.
Implements Johnson's algorithm for finding all simple cycles in a directed graph,
along with helper functions for strongly connected components.
References:
Donald B Johnson. "Finding all the elementary circuits of a directed graph."
SIAM Journal on Computing. 1975.
"""

from collections import defaultdict
from collections.abc import Generator


def simple_cycles(graph: dict[str, set[str]]) -> Generator[list[str], None, None]:
"""Find all simple cycles in a directed graph using Johnson's algorithm.
Args:
graph: A dictionary mapping each vertex to a set of its neighbours.
Yields:
Each elementary cycle as a list of vertices.
"""
graph = {v: set(nbrs) for v, nbrs in graph.items()}
sccs = _strongly_connected_components(graph)
while sccs:
scc = sccs.pop()
startnode = scc.pop()
path = [startnode]
blocked: set[str] = set()
closed: set[str] = set()
blocked.add(startnode)
B: dict[str, set[str]] = defaultdict(set)
stack: list[tuple[str, list[str]]] = [(startnode, list(graph[startnode]))]
while stack:
thisnode, nbrs = stack[-1]
if nbrs:
nextnode = nbrs.pop()
if nextnode == startnode:
yield path[:]
closed.update(path)
elif nextnode not in blocked:
path.append(nextnode)
stack.append((nextnode, list(graph[nextnode])))
closed.discard(nextnode)
blocked.add(nextnode)
continue
if not nbrs:
if thisnode in closed:
_unblock(thisnode, blocked, B)
else:
for nbr in graph[thisnode]:
if thisnode not in B[nbr]:
B[nbr].add(thisnode)
stack.pop()
path.pop()
_remove_node(graph, startnode)
H = _subgraph(graph, set(scc))
sccs.extend(_strongly_connected_components(H))


def _unblock(thisnode: str, blocked: set[str], B: dict[str, set[str]]) -> None:
"""Unblock a node and recursively unblock nodes in its B set."""
stack = {thisnode}
while stack:
node = stack.pop()
if node in blocked:
blocked.remove(node)
stack.update(B[node])
B[node].clear()


def _strongly_connected_components(graph: dict[str, set[str]]) -> list[set[str]]:
"""Find all strongly connected components using Tarjan's algorithm.
Args:
graph: A dictionary mapping each vertex to a set of its neighbours.
Returns:
A list of sets, each containing the vertices of a strongly connected component.
"""
index_counter = [0]
stack: list[str] = []
lowlink: dict[str, int] = {}
index: dict[str, int] = {}
result: list[set[str]] = []

def _strong_connect(node: str) -> None:
index[node] = index_counter[0]
lowlink[node] = index_counter[0]
index_counter[0] += 1
stack.append(node)

for successor in graph.get(node, set()):
if successor not in index:
_strong_connect(successor)
lowlink[node] = min(lowlink[node], lowlink[successor])
elif successor in stack:
lowlink[node] = min(lowlink[node], index[successor])

if lowlink[node] == index[node]:
connected_component: set[str] = set()
while True:
successor = stack.pop()
connected_component.add(successor)
if successor == node:
break
result.append(connected_component)

for node in graph:
if node not in index:
_strong_connect(node)

return result


def _remove_node(graph: dict[str, set[str]], target: str) -> None:
"""Remove a node and all its edges from the graph."""
del graph[target]
for nbrs in graph.values():
nbrs.discard(target)


def _subgraph(graph: dict[str, set[str]], vertices: set[str]) -> dict[str, set[str]]:
"""Get the subgraph induced by a set of vertices."""
return {v: graph[v] & vertices for v in vertices}
197 changes: 197 additions & 0 deletions plugboard-schemas/plugboard_schemas/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Validation utilities for process topology.

Provides functions to validate process topology including:
- Checking that all component inputs are connected
- Checking that input events have matching output event producers
- Checking for circular connections that require initial values

All validators accept the output of ``process.dict()`` or the relevant
sub-structures thereof.
"""

from __future__ import annotations

from collections import defaultdict
import typing as _t

from ._graph import simple_cycles


def _build_component_graph(
connectors: dict[str, dict[str, _t.Any]],
) -> dict[str, set[str]]:
"""Build a directed graph of component connections from connector dicts.

Args:
connectors: Dictionary mapping connector IDs to connector dicts,
as returned by ``process.dict()["connectors"]``.

Returns:
A dictionary mapping source component names to sets of target component names.
"""
graph: dict[str, set[str]] = defaultdict(set)
for conn_info in connectors.values():
spec = conn_info["spec"]
source_entity = spec["source"]["entity"]
target_entity = spec["target"]["entity"]
if source_entity != target_entity:
graph[source_entity].add(target_entity)
if target_entity not in graph:
graph[target_entity] = set()
return dict(graph)


def _get_edges_in_cycle(
cycle: list[str],
connectors: dict[str, dict[str, _t.Any]],
) -> list[dict[str, _t.Any]]:
"""Get all connector spec dicts that form edges within a cycle.

Args:
cycle: List of component names forming a cycle.
connectors: Dictionary mapping connector IDs to connector dicts.

Returns:
List of connector spec dicts that are part of the cycle.
"""
cycle_edges: list[dict[str, _t.Any]] = []
for i, node in enumerate(cycle):
next_node = cycle[(i + 1) % len(cycle)]
for conn_info in connectors.values():
spec = conn_info["spec"]
if spec["source"]["entity"] == node and spec["target"]["entity"] == next_node:
cycle_edges.append(spec)
return cycle_edges


def validate_all_inputs_connected(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check that all component inputs are connected.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"``
and ``"connectors"`` keys.

Returns:
List of error messages for unconnected inputs.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]

connected_inputs: dict[str, set[str]] = defaultdict(set)
for conn_info in connectors.values():
spec = conn_info["spec"]
target_name = spec["target"]["entity"]
target_field = spec["target"]["descriptor"]
connected_inputs[target_name].add(target_field)

errors: list[str] = []
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
all_inputs = set(io.get("inputs", []))
connected = connected_inputs.get(comp_name, set())
unconnected = all_inputs - connected
if unconnected:
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
return errors


def validate_input_events(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check that all components with input events have a matching output event producer.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"`` key.

Returns:
List of error messages for unmatched input events.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]

all_output_events: set[str] = set()
for comp_data in components.values():
io = comp_data.get("io", {})
all_output_events.update(io.get("output_events", []))

errors: list[str] = []
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
input_events = set(io.get("input_events", []))
unmatched = input_events - all_output_events
if unmatched:
errors.append(
f"Component '{comp_name}' has input events with no producer: {sorted(unmatched)}"
)
return errors


def validate_no_unresolved_cycles(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check for circular connections that are not resolved by initial values.

Circular loops are only valid if there are ``initial_values`` set on an
appropriate component input within the loop.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"``
and ``"connectors"`` keys.

Returns:
List of error messages for unresolved circular connections.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]

graph = _build_component_graph(connectors)
if not graph:
return []

# Build lookup of component initial_values by name
initial_values_by_comp: dict[str, set[str]] = {}
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
iv = io.get("initial_values", {})
if iv:
initial_values_by_comp[comp_name] = set(iv.keys())

errors: list[str] = []
for cycle in simple_cycles(graph):
cycle_edges = _get_edges_in_cycle(cycle, connectors)
cycle_resolved = False
for edge in cycle_edges:
target_comp = edge["target"]["entity"]
target_field = edge["target"]["descriptor"]
if target_comp in initial_values_by_comp:
if target_field in initial_values_by_comp[target_comp]:
cycle_resolved = True
break
if not cycle_resolved:
cycle_str = " -> ".join(cycle + [cycle[0]])
errors.append(
f"Circular connection detected without initial values: {cycle_str}. "
f"Set initial_values on a component input within the loop to resolve."
)
return errors


def validate_process(process_dict: dict[str, _t.Any]) -> list[str]:
"""Run all topology validation checks on a process.

This is the main validation entry point. It accepts the output of
``process.dict()`` and runs every available check, returning a
combined list of error messages.

Args:
process_dict: The output of ``process.dict()``.

Returns:
List of error messages. An empty list indicates a valid topology.
"""
errors: list[str] = []
errors.extend(validate_all_inputs_connected(process_dict))
errors.extend(validate_input_events(process_dict))
errors.extend(validate_no_unresolved_cycles(process_dict))
return errors
30 changes: 29 additions & 1 deletion plugboard/cli/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from plugboard.diagram import MermaidDiagram
from plugboard.process import Process, ProcessBuilder
from plugboard.schemas import ConfigSpec
from plugboard.schemas import ConfigSpec, validate_process
from plugboard.tune import Tuner
from plugboard.utils import add_sys_path, run_coro_sync

Expand Down Expand Up @@ -164,3 +164,31 @@ def diagram(
diagram = MermaidDiagram.from_process(process)
md = Markdown(f"```\n{diagram.diagram}\n```\n[Editable diagram]({diagram.url}) (external link)")
print(md)


@app.command()
def validate(
config: Annotated[
Path,
typer.Argument(
exists=True,
file_okay=True,
dir_okay=False,
writable=False,
readable=True,
resolve_path=True,
help="Path to the YAML configuration file.",
),
],
) -> None:
"""Validate a Plugboard process configuration."""
config_spec = _read_yaml(config)
with add_sys_path(config.parent):
process = _build_process(config_spec)
errors = validate_process(process.dict())
if errors:
stderr.print("[red]Validation failed:[/red]")
for error in errors:
stderr.print(f" • {error}")
raise typer.Exit(1)
print("[green]Validation passed[/green]")
Loading
Loading