Skip to content

Commit 26c0cb5

Browse files
Copilottoby-coleman
andcommitted
Add process topology validation: Johnson's algorithm, cycle detection, and validation utilities
- Add _graph.py with Johnson's cycle-finding algorithm (no additional dependencies) - Add _validation.py with validation utilities for process topology: - validate_all_inputs_connected: check all component inputs are connected - validate_input_events: check input events have matching output producers - validate_no_unresolved_cycles: check circular connections have initial_values - Add model validator on ProcessSpec for circular connection detection - Export new utilities from plugboard_schemas - Add comprehensive unit tests Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com>
1 parent a057bd7 commit 26c0cb5

5 files changed

Lines changed: 656 additions & 0 deletions

File tree

plugboard-schemas/plugboard_schemas/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
from importlib.metadata import version
1010

1111
from ._common import PlugboardBaseModel
12+
from ._graph import simple_cycles
13+
from ._validation import (
14+
ValidationError,
15+
validate_all_inputs_connected,
16+
validate_input_events,
17+
validate_no_unresolved_cycles,
18+
)
1219
from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource
1320
from .config import ConfigSpec, ProcessConfigSpec
1421
from .connector import (
@@ -85,4 +92,9 @@
8592
"TuneArgsDict",
8693
"TuneArgsSpec",
8794
"TuneSpec",
95+
"ValidationError",
96+
"simple_cycles",
97+
"validate_all_inputs_connected",
98+
"validate_input_events",
99+
"validate_no_unresolved_cycles",
88100
]
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Graph algorithms for topology validation.
2+
3+
Implements Johnson's algorithm for finding all simple cycles in a directed graph,
4+
along with helper functions for strongly connected components.
5+
6+
References:
7+
Donald B Johnson. "Finding all the elementary circuits of a directed graph."
8+
SIAM Journal on Computing. 1975.
9+
"""
10+
11+
from collections import defaultdict
12+
from collections.abc import Generator
13+
14+
15+
def simple_cycles(graph: dict[str, set[str]]) -> Generator[list[str], None, None]:
16+
"""Find all simple cycles in a directed graph using Johnson's algorithm.
17+
18+
Args:
19+
graph: A dictionary mapping each vertex to a set of its neighbours.
20+
21+
Yields:
22+
Each elementary cycle as a list of vertices.
23+
"""
24+
graph = {v: set(nbrs) for v, nbrs in graph.items()}
25+
sccs = _strongly_connected_components(graph)
26+
while sccs:
27+
scc = sccs.pop()
28+
startnode = scc.pop()
29+
path = [startnode]
30+
blocked: set[str] = set()
31+
closed: set[str] = set()
32+
blocked.add(startnode)
33+
B: dict[str, set[str]] = defaultdict(set)
34+
stack: list[tuple[str, list[str]]] = [(startnode, list(graph[startnode]))]
35+
while stack:
36+
thisnode, nbrs = stack[-1]
37+
if nbrs:
38+
nextnode = nbrs.pop()
39+
if nextnode == startnode:
40+
yield path[:]
41+
closed.update(path)
42+
elif nextnode not in blocked:
43+
path.append(nextnode)
44+
stack.append((nextnode, list(graph[nextnode])))
45+
closed.discard(nextnode)
46+
blocked.add(nextnode)
47+
continue
48+
if not nbrs:
49+
if thisnode in closed:
50+
_unblock(thisnode, blocked, B)
51+
else:
52+
for nbr in graph[thisnode]:
53+
if thisnode not in B[nbr]:
54+
B[nbr].add(thisnode)
55+
stack.pop()
56+
path.pop()
57+
_remove_node(graph, startnode)
58+
H = _subgraph(graph, set(scc))
59+
sccs.extend(_strongly_connected_components(H))
60+
61+
62+
def _unblock(thisnode: str, blocked: set[str], B: dict[str, set[str]]) -> None:
63+
"""Unblock a node and recursively unblock nodes in its B set."""
64+
stack = {thisnode}
65+
while stack:
66+
node = stack.pop()
67+
if node in blocked:
68+
blocked.remove(node)
69+
stack.update(B[node])
70+
B[node].clear()
71+
72+
73+
def _strongly_connected_components(graph: dict[str, set[str]]) -> list[set[str]]:
74+
"""Find all strongly connected components using Tarjan's algorithm.
75+
76+
Args:
77+
graph: A dictionary mapping each vertex to a set of its neighbours.
78+
79+
Returns:
80+
A list of sets, each containing the vertices of a strongly connected component.
81+
"""
82+
index_counter = [0]
83+
stack: list[str] = []
84+
lowlink: dict[str, int] = {}
85+
index: dict[str, int] = {}
86+
result: list[set[str]] = []
87+
88+
def _strong_connect(node: str) -> None:
89+
index[node] = index_counter[0]
90+
lowlink[node] = index_counter[0]
91+
index_counter[0] += 1
92+
stack.append(node)
93+
94+
for successor in graph.get(node, set()):
95+
if successor not in index:
96+
_strong_connect(successor)
97+
lowlink[node] = min(lowlink[node], lowlink[successor])
98+
elif successor in stack:
99+
lowlink[node] = min(lowlink[node], index[successor])
100+
101+
if lowlink[node] == index[node]:
102+
connected_component: set[str] = set()
103+
while True:
104+
successor = stack.pop()
105+
connected_component.add(successor)
106+
if successor == node:
107+
break
108+
result.append(connected_component)
109+
110+
for node in graph:
111+
if node not in index:
112+
_strong_connect(node)
113+
114+
return result
115+
116+
117+
def _remove_node(graph: dict[str, set[str]], target: str) -> None:
118+
"""Remove a node and all its edges from the graph."""
119+
del graph[target]
120+
for nbrs in graph.values():
121+
nbrs.discard(target)
122+
123+
124+
def _subgraph(graph: dict[str, set[str]], vertices: set[str]) -> dict[str, set[str]]:
125+
"""Get the subgraph induced by a set of vertices."""
126+
return {v: graph[v] & vertices for v in vertices}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""Validation utilities for `ProcessSpec` objects.
2+
3+
Provides functions to validate process topology including:
4+
- Checking that all component inputs are connected
5+
- Checking that input events have matching output event producers
6+
- Checking for circular connections that require initial values
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from collections import defaultdict
12+
import typing as _t
13+
14+
from ._graph import simple_cycles
15+
16+
17+
if _t.TYPE_CHECKING:
18+
from .component import ComponentSpec
19+
from .connector import ConnectorSpec
20+
21+
22+
class ValidationError(Exception):
23+
"""Raised when a process specification fails validation."""
24+
25+
pass
26+
27+
28+
def _build_component_graph(
29+
connectors: list[ConnectorSpec],
30+
) -> dict[str, set[str]]:
31+
"""Build a directed graph of component connections from connector specs.
32+
33+
Args:
34+
connectors: List of connector specifications.
35+
36+
Returns:
37+
A dictionary mapping source component names to sets of target component names.
38+
"""
39+
graph: dict[str, set[str]] = defaultdict(set)
40+
for conn in connectors:
41+
source_entity = conn.source.entity
42+
target_entity = conn.target.entity
43+
if source_entity != target_entity:
44+
graph[source_entity].add(target_entity)
45+
# Ensure target is in graph even with no outgoing edges
46+
if target_entity not in graph:
47+
graph[target_entity] = set()
48+
return dict(graph)
49+
50+
51+
def _get_edges_in_cycle(
52+
cycle: list[str],
53+
connectors: list[ConnectorSpec],
54+
) -> list[ConnectorSpec]:
55+
"""Get all connector specs that form edges within a cycle.
56+
57+
Args:
58+
cycle: List of component names forming a cycle.
59+
connectors: All connector specifications.
60+
61+
Returns:
62+
List of connector specs that are part of the cycle.
63+
"""
64+
cycle_edges: list[ConnectorSpec] = []
65+
cycle_set = set(cycle)
66+
for i, node in enumerate(cycle):
67+
next_node = cycle[(i + 1) % len(cycle)]
68+
for conn in connectors:
69+
if conn.source.entity == node and conn.target.entity == next_node:
70+
cycle_edges.append(conn)
71+
return [c for c in cycle_edges if c.source.entity in cycle_set and c.target.entity in cycle_set]
72+
73+
74+
def validate_all_inputs_connected(
75+
components: dict[str, dict[str, _t.Any]],
76+
connectors: list[ConnectorSpec],
77+
) -> list[str]:
78+
"""Check that all component inputs are connected.
79+
80+
Args:
81+
components: Dictionary mapping component names to their IO info.
82+
Each value must have an ``"inputs"`` key with a list of input field names.
83+
connectors: List of connector specifications.
84+
85+
Returns:
86+
List of error messages for unconnected inputs.
87+
"""
88+
# Build mapping of which component inputs are connected
89+
connected_inputs: dict[str, set[str]] = defaultdict(set)
90+
for conn in connectors:
91+
target_name = conn.target.entity
92+
target_field = conn.target.descriptor
93+
connected_inputs[target_name].add(target_field)
94+
95+
errors: list[str] = []
96+
for comp_name, comp_info in components.items():
97+
all_inputs = set(comp_info.get("inputs", []))
98+
connected = connected_inputs.get(comp_name, set())
99+
unconnected = all_inputs - connected
100+
if unconnected:
101+
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
102+
return errors
103+
104+
105+
def validate_input_events(
106+
components: dict[str, dict[str, _t.Any]],
107+
) -> list[str]:
108+
"""Check that all components with input events have a matching output event producer.
109+
110+
Args:
111+
components: Dictionary mapping component names to their IO info.
112+
Each value must have ``"input_events"`` and ``"output_events"`` keys
113+
with lists of event type strings.
114+
115+
Returns:
116+
List of error messages for unmatched input events.
117+
"""
118+
# Collect all output event types across all components
119+
all_output_events: set[str] = set()
120+
for comp_info in components.values():
121+
all_output_events.update(comp_info.get("output_events", []))
122+
123+
errors: list[str] = []
124+
for comp_name, comp_info in components.items():
125+
input_events = set(comp_info.get("input_events", []))
126+
unmatched = input_events - all_output_events
127+
if unmatched:
128+
errors.append(
129+
f"Component '{comp_name}' has input events with no producer: {sorted(unmatched)}"
130+
)
131+
return errors
132+
133+
134+
def validate_no_unresolved_cycles(
135+
components: list[ComponentSpec],
136+
connectors: list[ConnectorSpec],
137+
) -> list[str]:
138+
"""Check for circular connections that are not resolved by initial values.
139+
140+
Circular loops are only valid if there are ``initial_values`` set on an
141+
appropriate component input within the loop.
142+
143+
Args:
144+
components: List of component specifications.
145+
connectors: List of connector specifications.
146+
147+
Returns:
148+
List of error messages for unresolved circular connections.
149+
"""
150+
graph = _build_component_graph(connectors)
151+
if not graph:
152+
return []
153+
154+
# Build lookup of component initial_values by name
155+
initial_values_by_comp: dict[str, set[str]] = {}
156+
for comp in components:
157+
if comp.args.initial_values:
158+
initial_values_by_comp[comp.args.name] = set(comp.args.initial_values.keys())
159+
160+
errors: list[str] = []
161+
for cycle in simple_cycles(graph):
162+
# Check if any edge in the cycle targets a component input with initial_values
163+
cycle_edges = _get_edges_in_cycle(cycle, connectors)
164+
cycle_resolved = False
165+
for edge in cycle_edges:
166+
target_comp = edge.target.entity
167+
target_field = edge.target.descriptor
168+
if target_comp in initial_values_by_comp:
169+
if target_field in initial_values_by_comp[target_comp]:
170+
cycle_resolved = True
171+
break
172+
if not cycle_resolved:
173+
cycle_str = " -> ".join(cycle + [cycle[0]])
174+
errors.append(
175+
f"Circular connection detected without initial values: {cycle_str}. "
176+
f"Set initial_values on a component input within the loop to resolve."
177+
)
178+
return errors

plugboard-schemas/plugboard_schemas/process.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing_extensions import Self
88

99
from ._common import PlugboardBaseModel
10+
from ._validation import validate_no_unresolved_cycles
1011
from .component import ComponentSpec
1112
from .connector import DEFAULT_CONNECTOR_CLS_PATH, ConnectorBuilderSpec, ConnectorSpec
1213
from .state import DEFAULT_STATE_BACKEND_CLS_PATH, RAY_STATE_BACKEND_CLS_PATH, StateBackendSpec
@@ -77,6 +78,14 @@ def _set_default_state_backend(self: Self) -> Self:
7778
self.args.state.type = RAY_STATE_BACKEND_CLS_PATH
7879
return self
7980

81+
@model_validator(mode="after")
82+
def _validate_no_unresolved_cycles(self: Self) -> Self:
83+
"""Validate that circular connections have initial_values set."""
84+
errors = validate_no_unresolved_cycles(self.args.components, self.args.connectors)
85+
if errors:
86+
raise ValueError("\n".join(errors))
87+
return self
88+
8089
@field_validator("type", mode="before")
8190
@classmethod
8291
def _validate_type(cls, value: _t.Any) -> str:

0 commit comments

Comments
 (0)