Skip to content

Commit b40300f

Browse files
Copilottoby-coleman
andcommitted
Refactor validators to runtime checks accepting process.dict() output
- Remove Pydantic model validator from ProcessSpec - Refactor all validators to accept process.dict() format - Add validate_process() entry point combining all checks - Call validate_process() in Process.init() raising ValidationError - Update plugboard.schemas re-exports - Update tests for new runtime validation approach Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com>
1 parent fc0bf55 commit b40300f

7 files changed

Lines changed: 326 additions & 245 deletions

File tree

plugboard-schemas/plugboard_schemas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
validate_all_inputs_connected,
1515
validate_input_events,
1616
validate_no_unresolved_cycles,
17+
validate_process,
1718
)
1819
from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource
1920
from .config import ConfigSpec, ProcessConfigSpec
@@ -95,4 +96,5 @@
9596
"validate_all_inputs_connected",
9697
"validate_input_events",
9798
"validate_no_unresolved_cycles",
99+
"validate_process",
98100
]
Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
"""Validation utilities for `ProcessSpec` objects.
1+
"""Validation utilities for process topology.
22
33
Provides functions to validate process topology including:
44
- Checking that all component inputs are connected
55
- Checking that input events have matching output event producers
66
- Checking for circular connections that require initial values
7+
8+
All validators accept the output of ``process.dict()`` or the relevant
9+
sub-structures thereof.
710
"""
811

912
from __future__ import annotations
@@ -14,80 +17,79 @@
1417
from ._graph import simple_cycles
1518

1619

17-
if _t.TYPE_CHECKING:
18-
from .component import ComponentSpec
19-
from .connector import ConnectorSpec
20-
21-
2220
def _build_component_graph(
23-
connectors: list[ConnectorSpec],
21+
connectors: dict[str, dict[str, _t.Any]],
2422
) -> dict[str, set[str]]:
25-
"""Build a directed graph of component connections from connector specs.
23+
"""Build a directed graph of component connections from connector dicts.
2624
2725
Args:
28-
connectors: List of connector specifications.
26+
connectors: Dictionary mapping connector IDs to connector dicts,
27+
as returned by ``process.dict()["connectors"]``.
2928
3029
Returns:
3130
A dictionary mapping source component names to sets of target component names.
3231
"""
3332
graph: dict[str, set[str]] = defaultdict(set)
34-
for conn in connectors:
35-
source_entity = conn.source.entity
36-
target_entity = conn.target.entity
33+
for conn_info in connectors.values():
34+
spec = conn_info["spec"]
35+
source_entity = spec["source"]["entity"]
36+
target_entity = spec["target"]["entity"]
3737
if source_entity != target_entity:
3838
graph[source_entity].add(target_entity)
39-
# Ensure target is in graph even with no outgoing edges
4039
if target_entity not in graph:
4140
graph[target_entity] = set()
4241
return dict(graph)
4342

4443

4544
def _get_edges_in_cycle(
4645
cycle: list[str],
47-
connectors: list[ConnectorSpec],
48-
) -> list[ConnectorSpec]:
49-
"""Get all connector specs that form edges within a cycle.
46+
connectors: dict[str, dict[str, _t.Any]],
47+
) -> list[dict[str, _t.Any]]:
48+
"""Get all connector spec dicts that form edges within a cycle.
5049
5150
Args:
5251
cycle: List of component names forming a cycle.
53-
connectors: All connector specifications.
52+
connectors: Dictionary mapping connector IDs to connector dicts.
5453
5554
Returns:
56-
List of connector specs that are part of the cycle.
55+
List of connector spec dicts that are part of the cycle.
5756
"""
58-
cycle_edges: list[ConnectorSpec] = []
57+
cycle_edges: list[dict[str, _t.Any]] = []
5958
for i, node in enumerate(cycle):
6059
next_node = cycle[(i + 1) % len(cycle)]
61-
for conn in connectors:
62-
if conn.source.entity == node and conn.target.entity == next_node:
63-
cycle_edges.append(conn)
60+
for conn_info in connectors.values():
61+
spec = conn_info["spec"]
62+
if spec["source"]["entity"] == node and spec["target"]["entity"] == next_node:
63+
cycle_edges.append(spec)
6464
return cycle_edges
6565

6666

6767
def validate_all_inputs_connected(
68-
components: dict[str, dict[str, _t.Any]],
69-
connectors: list[ConnectorSpec],
68+
process_dict: dict[str, _t.Any],
7069
) -> list[str]:
7170
"""Check that all component inputs are connected.
7271
7372
Args:
74-
components: Dictionary mapping component names to their IO info.
75-
Each value must have an ``"inputs"`` key with a list of input field names.
76-
connectors: List of connector specifications.
73+
process_dict: The output of ``process.dict()``. Uses the ``"components"``
74+
and ``"connectors"`` keys.
7775
7876
Returns:
7977
List of error messages for unconnected inputs.
8078
"""
81-
# Build mapping of which component inputs are connected
79+
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
80+
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]
81+
8282
connected_inputs: dict[str, set[str]] = defaultdict(set)
83-
for conn in connectors:
84-
target_name = conn.target.entity
85-
target_field = conn.target.descriptor
83+
for conn_info in connectors.values():
84+
spec = conn_info["spec"]
85+
target_name = spec["target"]["entity"]
86+
target_field = spec["target"]["descriptor"]
8687
connected_inputs[target_name].add(target_field)
8788

8889
errors: list[str] = []
89-
for comp_name, comp_info in components.items():
90-
all_inputs = set(comp_info.get("inputs", []))
90+
for comp_name, comp_data in components.items():
91+
io = comp_data.get("io", {})
92+
all_inputs = set(io.get("inputs", []))
9193
connected = connected_inputs.get(comp_name, set())
9294
unconnected = all_inputs - connected
9395
if unconnected:
@@ -96,26 +98,27 @@ def validate_all_inputs_connected(
9698

9799

98100
def validate_input_events(
99-
components: dict[str, dict[str, _t.Any]],
101+
process_dict: dict[str, _t.Any],
100102
) -> list[str]:
101103
"""Check that all components with input events have a matching output event producer.
102104
103105
Args:
104-
components: Dictionary mapping component names to their IO info.
105-
Each value must have ``"input_events"`` and ``"output_events"`` keys
106-
with lists of event type strings.
106+
process_dict: The output of ``process.dict()``. Uses the ``"components"`` key.
107107
108108
Returns:
109109
List of error messages for unmatched input events.
110110
"""
111-
# Collect all output event types across all components
111+
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
112+
112113
all_output_events: set[str] = set()
113-
for comp_info in components.values():
114-
all_output_events.update(comp_info.get("output_events", []))
114+
for comp_data in components.values():
115+
io = comp_data.get("io", {})
116+
all_output_events.update(io.get("output_events", []))
115117

116118
errors: list[str] = []
117-
for comp_name, comp_info in components.items():
118-
input_events = set(comp_info.get("input_events", []))
119+
for comp_name, comp_data in components.items():
120+
io = comp_data.get("io", {})
121+
input_events = set(io.get("input_events", []))
119122
unmatched = input_events - all_output_events
120123
if unmatched:
121124
errors.append(
@@ -125,39 +128,42 @@ def validate_input_events(
125128

126129

127130
def validate_no_unresolved_cycles(
128-
components: list[ComponentSpec],
129-
connectors: list[ConnectorSpec],
131+
process_dict: dict[str, _t.Any],
130132
) -> list[str]:
131133
"""Check for circular connections that are not resolved by initial values.
132134
133135
Circular loops are only valid if there are ``initial_values`` set on an
134136
appropriate component input within the loop.
135137
136138
Args:
137-
components: List of component specifications.
138-
connectors: List of connector specifications.
139+
process_dict: The output of ``process.dict()``. Uses the ``"components"``
140+
and ``"connectors"`` keys.
139141
140142
Returns:
141143
List of error messages for unresolved circular connections.
142144
"""
145+
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
146+
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]
147+
143148
graph = _build_component_graph(connectors)
144149
if not graph:
145150
return []
146151

147152
# Build lookup of component initial_values by name
148153
initial_values_by_comp: dict[str, set[str]] = {}
149-
for comp in components:
150-
if comp.args.initial_values:
151-
initial_values_by_comp[comp.args.name] = set(comp.args.initial_values.keys())
154+
for comp_name, comp_data in components.items():
155+
io = comp_data.get("io", {})
156+
iv = io.get("initial_values", {})
157+
if iv:
158+
initial_values_by_comp[comp_name] = set(iv.keys())
152159

153160
errors: list[str] = []
154161
for cycle in simple_cycles(graph):
155-
# Check if any edge in the cycle targets a component input with initial_values
156162
cycle_edges = _get_edges_in_cycle(cycle, connectors)
157163
cycle_resolved = False
158164
for edge in cycle_edges:
159-
target_comp = edge.target.entity
160-
target_field = edge.target.descriptor
165+
target_comp = edge["target"]["entity"]
166+
target_field = edge["target"]["descriptor"]
161167
if target_comp in initial_values_by_comp:
162168
if target_field in initial_values_by_comp[target_comp]:
163169
cycle_resolved = True
@@ -169,3 +175,23 @@ def validate_no_unresolved_cycles(
169175
f"Set initial_values on a component input within the loop to resolve."
170176
)
171177
return errors
178+
179+
180+
def validate_process(process_dict: dict[str, _t.Any]) -> list[str]:
181+
"""Run all topology validation checks on a process.
182+
183+
This is the main validation entry point. It accepts the output of
184+
``process.dict()`` and runs every available check, returning a
185+
combined list of error messages.
186+
187+
Args:
188+
process_dict: The output of ``process.dict()``.
189+
190+
Returns:
191+
List of error messages. An empty list indicates a valid topology.
192+
"""
193+
errors: list[str] = []
194+
errors.extend(validate_all_inputs_connected(process_dict))
195+
errors.extend(validate_input_events(process_dict))
196+
errors.extend(validate_no_unresolved_cycles(process_dict))
197+
return errors

plugboard-schemas/plugboard_schemas/process.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing_extensions import Self
88

99
from ._common import PlugboardBaseModel
10-
from ._validation import validate_no_unresolved_cycles
1110
from .component import ComponentSpec
1211
from .connector import DEFAULT_CONNECTOR_CLS_PATH, ConnectorBuilderSpec, ConnectorSpec
1312
from .state import DEFAULT_STATE_BACKEND_CLS_PATH, RAY_STATE_BACKEND_CLS_PATH, StateBackendSpec
@@ -78,14 +77,6 @@ def _set_default_state_backend(self: Self) -> Self:
7877
self.args.state.type = RAY_STATE_BACKEND_CLS_PATH
7978
return self
8079

81-
@model_validator(mode="after")
82-
def _validate_no_unresolved_cycles(self: Self) -> Self:
83-
"""Validate that circular connections have initial_values set."""
84-
errors = validate_no_unresolved_cycles(self.args.components, self.args.connectors)
85-
if errors:
86-
raise ValueError("\n".join(errors))
87-
return self
88-
8980
@field_validator("type", mode="before")
9081
@classmethod
9182
def _validate_type(cls, value: _t.Any) -> str:

plugboard/process/process.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
from plugboard.component import Component
1515
from plugboard.connector import Connector
16-
from plugboard.exceptions import NotInitialisedError
17-
from plugboard.schemas import ConfigSpec, Status
16+
from plugboard.exceptions import NotInitialisedError, ValidationError
17+
from plugboard.schemas import ConfigSpec, Status, validate_process
1818
from plugboard.state import DictStateBackend, StateBackend
1919
from plugboard.utils import DI, ExportMixin, gen_rand_str
2020
from plugboard.utils.async_utils import run_coro_sync
@@ -109,6 +109,9 @@ async def _set_status(self, status: Status, publish: bool = True) -> None:
109109
@abstractmethod
110110
async def init(self) -> None:
111111
"""Performs component initialisation actions."""
112+
errors = validate_process(self.dict())
113+
if errors:
114+
raise ValidationError("\n".join(errors))
112115
self._is_initialised = True
113116
await self._set_status(Status.INIT)
114117

plugboard/schemas/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@
4646
TuneArgsDict,
4747
TuneArgsSpec,
4848
TuneSpec,
49+
simple_cycles,
50+
validate_all_inputs_connected,
51+
validate_input_events,
52+
validate_no_unresolved_cycles,
53+
validate_process,
4954
)
5055

5156

@@ -86,4 +91,9 @@
8691
"TuneArgsDict",
8792
"TuneArgsSpec",
8893
"TuneSpec",
94+
"simple_cycles",
95+
"validate_all_inputs_connected",
96+
"validate_input_events",
97+
"validate_no_unresolved_cycles",
98+
"validate_process",
8999
]

tests/integration/test_process_validation.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
from tests.integration.test_process_with_components_run import A, B, C
1616

1717

18-
# TODO: Update these tests when we implement full graph validation
19-
20-
2118
def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict]:
2219
"""Filters the log output by applying regex to a field."""
2320
pattern = re.compile(regex)
@@ -26,20 +23,15 @@ def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict
2623

2724
@pytest.mark.asyncio
2825
async def test_missing_connections() -> None:
29-
"""Tests that missing connections are logged."""
26+
"""Tests that missing input connections raise ValidationError."""
3027
p_missing_input = LocalProcess(
3128
components=[A(name="a", iters=10), C(name="c", path="test-out.csv")],
3229
# c.in_1 is not connected
3330
connectors=[AsyncioConnector(spec=ConnectorSpec(source="a.out_1", target="unknown.x"))],
3431
)
35-
with capture_logs() as logs:
32+
with pytest.raises(exceptions.ValidationError, match="unconnected inputs"):
3633
await p_missing_input.init()
3734

38-
# Must contain an error-level log indicating that input is not connected
39-
logs = filter_logs(logs, "log_level", "error")
40-
logs = filter_logs(logs, "event", "Input fields not connected")
41-
assert logs, "Logs do not indicate missing connection"
42-
4335
p_missing_output = LocalProcess(
4436
components=[A(name="a", iters=10), B(name="b")],
4537
# b.out_1 is not connected

0 commit comments

Comments
 (0)