Skip to content

Commit 97cc787

Browse files
test(conformance): drive 0011 parallel-branches
Extends the conformance harness so the eight spec fixtures introduced by proposal 0011 parse and run end-to-end: - pipeline-utilities/032-038 (parallel-branches basic, fail-fast, collect, different-state-schemas, with-branch-middleware-retry, determinism, compose-with-fan-out). - graph-engine/021-observer-branch-name (NodeEvent.branch_name). Harness changes: - ParallelBranchSpec / ParallelBranchesSpec models on NodeSpec; sleep_ms node-companion modifier; recoverable_state expected key on the pipeline-utilities expected discriminator. - adapter.build_graph dispatches the parallel_branches directive to builder.add_parallel_branches_node, wrapping the result in a tracing variant so the execution-order trace records the dispatcher as one engine step. - test_pipeline_utilities driver: lifts the fixture-number gate from 23 to 38; loads top-level plural subgraphs blocks; translates per-branch middleware lists; wires graph-attached observers per run; asserts parallel-branches observer_event_invariants (branch_started_event_order, alpha_inner_attempt_indices_seen, fan-out-inside-branch invariants, plain-branch invariants); routes fail_fast assertions to surface branch_name + cause_message + recoverable_state alongside category. - test_conformance driver: drops the 021 skip and adds an invariants checker for outermost-events-have-no-branch-name and inner-events-carry-correct-branch-name. The checkpoint-resume fixtures (024-031) move into the deferred set since their cases-shape with first_run_expected_error / resume: blocks is driven by test_checkpoint.py.
1 parent 31fca43 commit 97cc787

6 files changed

Lines changed: 509 additions & 73 deletions

File tree

tests/conformance/adapter.py

Lines changed: 130 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from __future__ import annotations
1212

13+
import asyncio
1314
import copy
1415
from collections.abc import Awaitable, Callable, Mapping, Sequence
1516
from dataclasses import dataclass, field
@@ -19,12 +20,14 @@
1920

2021
from openarmature.graph import (
2122
END,
23+
BranchSpec,
2224
CompiledGraph,
2325
EndSentinel,
2426
ExplicitMapping,
2527
FanOutNode,
2628
FieldNameMatching,
2729
GraphBuilder,
30+
ParallelBranchesNode,
2831
ProjectionStrategy,
2932
Reducer,
3033
State,
@@ -56,6 +59,14 @@ def _parse_type(s: str) -> Any:
5659
return float
5760
if s == "bool":
5861
return bool
62+
# Unparameterized container types — parallel-branches fixtures
63+
# 034/035/037 use ``dict`` and ``list<dict>`` as state-field types
64+
# for accumulator slots (branch_errors, merged_dict, collected_labels)
65+
# where the element shape is heterogeneous across branches.
66+
if s == "dict":
67+
return dict[str, Any]
68+
if s == "list<dict>":
69+
return list[dict[str, Any]]
5970
if s.startswith("list<") and s.endswith(">"):
6071
return list[_parse_type(s[5:-1])]
6172
if s.startswith("dict<") and s.endswith(">"):
@@ -357,6 +368,23 @@ async def fn(_state: Any) -> Mapping[str, Any]:
357368
return fn
358369

359370

371+
def _wrap_with_sleep(
372+
fn: Callable[[Any], Awaitable[Mapping[str, Any]]],
373+
sleep_ms: int,
374+
) -> Callable[[Any], Awaitable[Mapping[str, Any]]]:
375+
# ``sleep_ms`` companion modifier on a NodeSpec — sleep that many
376+
# milliseconds before the wrapped body fires. Used by parallel-branches
377+
# fixtures 033 (slow third branch for fail-fast cancellation) and 037
378+
# (randomized completion timing to verify insertion-order determinism).
379+
delay = sleep_ms / 1000.0
380+
381+
async def fn_with_sleep(state: Any) -> Mapping[str, Any]:
382+
await asyncio.sleep(delay)
383+
return await fn(state)
384+
385+
return fn_with_sleep
386+
387+
360388
@dataclass(frozen=True)
361389
class _TracingFanOutNode(FanOutNode[State, State]):
362390
"""Conformance helper: a FanOutNode that appends its name to a shared
@@ -383,6 +411,24 @@ async def run_with_context(
383411
)
384412

385413

414+
@dataclass(frozen=True)
415+
class _TracingParallelBranchesNode(ParallelBranchesNode[State]):
416+
"""Conformance helper: a ParallelBranchesNode that appends its name
417+
to the shared trace list once when the engine runs it. The
418+
parallel-branches dispatcher itself counts as one engine step from
419+
the parent's POV per §11.6, mirroring the fan-out tracing wrapper."""
420+
421+
trace_list: list[str] = field(default_factory=list[str])
422+
423+
async def run_with_context(
424+
self,
425+
state: State,
426+
context: _InvocationContext,
427+
) -> Mapping[str, Any]:
428+
self.trace_list.append(self.name)
429+
return await super().run_with_context(state, context)
430+
431+
386432
@dataclass(frozen=True)
387433
class _TracingSubgraphNode(SubgraphNode[State, State]):
388434
"""Conformance helper: a SubgraphNode that appends its name to a shared
@@ -457,6 +503,7 @@ def build_graph(
457503
node_middleware: Mapping[str, Sequence[Any]] | None = None,
458504
graph_middleware: Sequence[Any] | None = None,
459505
fan_out_instance_middleware: Mapping[str, Sequence[Any]] | None = None,
506+
parallel_branches_branch_middleware: Mapping[str, Mapping[str, Sequence[Any]]] | None = None,
460507
) -> BuiltGraph:
461508
"""Translate a graph-shaped fixture block into a `BuiltGraph`.
462509
@@ -486,6 +533,7 @@ def build_graph(
486533
subgraphs = subgraphs or {}
487534
node_middleware = node_middleware or {}
488535
fan_out_instance_middleware = fan_out_instance_middleware or {}
536+
parallel_branches_branch_middleware = parallel_branches_branch_middleware or {}
489537

490538
for mw in graph_middleware or ():
491539
builder.add_middleware(mw)
@@ -505,7 +553,8 @@ def build_graph(
505553
trace_list=trace,
506554
middleware=per_node_mw,
507555
)
508-
elif "fan_out" in node_spec:
556+
continue
557+
if "fan_out" in node_spec:
509558
_add_fan_out_node(
510559
builder,
511560
node_name,
@@ -514,55 +563,47 @@ def build_graph(
514563
trace,
515564
instance_middleware=fan_out_instance_middleware.get(node_name, ()),
516565
)
517-
elif "raises" in node_spec:
518-
builder.add_node(
566+
continue
567+
if "parallel_branches" in node_spec:
568+
_add_parallel_branches_node(
569+
builder,
519570
node_name,
520-
_make_raising_fn(node_name, node_spec["raises"], trace),
521-
middleware=per_node_mw,
571+
node_spec["parallel_branches"],
572+
subgraphs,
573+
trace,
574+
branch_middleware=parallel_branches_branch_middleware.get(node_name, {}),
522575
)
576+
continue
577+
578+
body: Callable[[Any], Awaitable[Mapping[str, Any]]]
579+
if "raises" in node_spec:
580+
body = _make_raising_fn(node_name, node_spec["raises"], trace)
523581
elif "flaky" in node_spec:
524-
builder.add_node(
525-
node_name,
526-
_make_flaky_fn(node_name, node_spec["flaky"], trace),
527-
middleware=per_node_mw,
528-
)
582+
body = _make_flaky_fn(node_name, node_spec["flaky"], trace)
529583
elif "flaky_by_index" in node_spec:
530-
builder.add_node(
531-
node_name,
532-
_make_flaky_by_index_fn(node_name, node_spec["flaky_by_index"], trace),
533-
middleware=per_node_mw,
534-
)
584+
body = _make_flaky_by_index_fn(node_name, node_spec["flaky_by_index"], trace)
535585
elif "flaky_instance_only" in node_spec:
536-
builder.add_node(
537-
node_name,
538-
_make_flaky_instance_only_fn(node_name, node_spec["flaky_instance_only"], trace),
539-
middleware=per_node_mw,
540-
)
586+
body = _make_flaky_instance_only_fn(node_name, node_spec["flaky_instance_only"], trace)
541587
elif "update" in node_spec:
542-
builder.add_node(
543-
node_name,
544-
_make_update_fn(node_name, node_spec["update"], trace),
545-
middleware=per_node_mw,
546-
)
588+
body = _make_update_fn(node_name, node_spec["update"], trace)
547589
elif "update_pure" in node_spec:
548-
builder.add_node(
549-
node_name,
550-
_make_pure_update_fn(node_name, node_spec["update_pure"], trace),
551-
middleware=per_node_mw,
552-
)
590+
body = _make_pure_update_fn(node_name, node_spec["update_pure"], trace)
553591
elif "update_from_field" in node_spec:
554-
builder.add_node(
555-
node_name,
556-
_make_update_from_field_fn(node_name, node_spec["update_from_field"], trace),
557-
middleware=per_node_mw,
558-
)
592+
body = _make_update_from_field_fn(node_name, node_spec["update_from_field"], trace)
559593
else:
560594
raise ValueError(
561595
f"node {node_name!r} has no recognized directive "
562596
"(update / update_pure / update_from_field / raises / flaky / "
563-
"flaky_by_index / flaky_instance_only / fan_out / subgraph)"
597+
"flaky_by_index / flaky_instance_only / fan_out / parallel_branches / "
598+
"subgraph)"
564599
)
565600

601+
sleep_ms = node_spec.get("sleep_ms")
602+
if sleep_ms is not None:
603+
body = _wrap_with_sleep(body, int(sleep_ms))
604+
605+
builder.add_node(node_name, body, middleware=per_node_mw)
606+
566607
for edge_spec in spec.get("edges", []):
567608
source = edge_spec["from"]
568609
if "to" in edge_spec:
@@ -623,6 +664,8 @@ def _record_event(event: NodeEvent) -> dict[str, Any]:
623664
rec["error"] = event.error.category
624665
if event.fan_out_index is not None:
625666
rec["fan_out_index"] = event.fan_out_index
667+
if event.branch_name is not None:
668+
rec["branch_name"] = event.branch_name
626669
return rec
627670

628671

@@ -734,6 +777,57 @@ def _add_fan_out_node(
734777
)
735778

736779

780+
def _add_parallel_branches_node(
781+
builder: GraphBuilder[Any],
782+
node_name: str,
783+
cfg: Mapping[str, Any],
784+
subgraphs: Mapping[str, CompiledGraph[State]],
785+
trace: list[str],
786+
*,
787+
branch_middleware: Mapping[str, Sequence[Any]],
788+
) -> None:
789+
"""Translate a fixture's ``parallel_branches:`` block into a
790+
``builder.add_parallel_branches_node`` call.
791+
792+
Each branch's ``subgraph`` name resolves against the shared
793+
``subgraphs`` registry (built from the fixture's top-level
794+
``subgraphs:`` block). ``branch_middleware`` maps branch-name to a
795+
pre-translated middleware list; the test driver populates it from
796+
each branch's ``middleware:`` block.
797+
"""
798+
branches_cfg = cast("dict[str, dict[str, Any]]", cfg["branches"])
799+
branches: dict[str, BranchSpec[Any]] = {}
800+
for branch_name, branch_cfg in branches_cfg.items():
801+
sub_compiled = subgraphs[branch_cfg["subgraph"]]
802+
branches[branch_name] = BranchSpec(
803+
subgraph=sub_compiled,
804+
inputs=dict(branch_cfg.get("inputs") or {}),
805+
outputs=dict(branch_cfg.get("outputs") or {}),
806+
middleware=tuple(branch_middleware.get(branch_name, ())),
807+
)
808+
809+
builder.add_parallel_branches_node(
810+
node_name,
811+
branches=branches,
812+
error_policy=cfg.get("error_policy", "fail_fast"),
813+
errors_field=cfg.get("errors_field"),
814+
)
815+
816+
# Swap the registered node for a tracing variant so the
817+
# conformance trace records the dispatcher as one engine step. The
818+
# builder's validation has already run; we only replace the stored
819+
# Node instance.
820+
original = cast("ParallelBranchesNode[State]", builder._nodes[node_name])
821+
builder._nodes[node_name] = _TracingParallelBranchesNode(
822+
name=original.name,
823+
branches=original.branches,
824+
error_policy=original.error_policy,
825+
errors_field=original.errors_field,
826+
middleware=original.middleware,
827+
trace_list=trace,
828+
)
829+
830+
737831
def _resolve_callable_int_resolver(cfg: Mapping[str, Any]) -> Callable[[Any], int]:
738832
"""Build a state-reader callable from a fixture's callable config.
739833

tests/conformance/harness/directives.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,33 @@ class FanOutSpec(_AllowExtras):
235235
instance_middleware: list[MiddlewareSpec] | None = None
236236

237237

238+
class ParallelBranchSpec(_AllowExtras):
239+
"""One entry inside a ``parallel_branches.branches`` mapping.
240+
241+
Permissive on extras because fixtures may carry extra knobs
242+
(e.g., per-branch annotations the harness ignores).
243+
"""
244+
245+
subgraph: str
246+
inputs: dict[str, str] | None = None
247+
outputs: dict[str, str] | None = None
248+
middleware: list[MiddlewareSpec] | None = None
249+
250+
251+
class ParallelBranchesSpec(_AllowExtras):
252+
"""``parallel_branches:`` block on a NodeSpec (pipeline-utilities §11).
253+
254+
Mirrors :class:`FanOutSpec` but topology-driven: M heterogeneous
255+
branches, each referencing a different compiled subgraph by name
256+
against the case's top-level ``subgraphs:`` block. Branch insertion
257+
order is preserved per §11.8.
258+
"""
259+
260+
branches: dict[str, ParallelBranchSpec]
261+
error_policy: Literal["fail_fast", "collect"] | None = None
262+
errors_field: str | None = None
263+
264+
238265
class CallsLlmSpec(_AllowExtras):
239266
"""LLM-using node: sends ``messages`` to the harness's mock provider
240267
and stores the response (assistant content) in ``stores_response_in``.
@@ -294,6 +321,7 @@ class NodeSpec(_ForbidExtras):
294321
raises: str | None = None
295322
subgraph: str | None = None
296323
fan_out: FanOutSpec | None = None
324+
parallel_branches: ParallelBranchesSpec | None = None
297325
flaky: FlakySpec | None = None
298326
flaky_by_index: FlakyByIndexSpec | None = None
299327
flaky_per_index: FlakyPerIndexSpec | None = None
@@ -309,6 +337,13 @@ class NodeSpec(_ForbidExtras):
309337
also_emits_via_global_tracer: GlobalTracerSpec | None = None
310338
# Pair with ``raises`` to specify the error category (graph-engine §4).
311339
error_category: str | None = None
340+
# Parallel-branches fixtures (033, 037): the node sleeps this many
341+
# milliseconds before its update fires. Used to force deterministic
342+
# branch-completion ordering (037 — different branches finish at
343+
# different wall-clock times yet final state must be insertion-order
344+
# deterministic per §11.8) and to slow a third branch so fail-fast
345+
# cancellation has time to land before it finishes (033).
346+
sleep_ms: int | None = None
312347

313348
_PRIMARY_FIELDS = (
314349
"update",
@@ -318,6 +353,7 @@ class NodeSpec(_ForbidExtras):
318353
"raises",
319354
"subgraph",
320355
"fan_out",
356+
"parallel_branches",
321357
"flaky",
322358
"flaky_by_index",
323359
"flaky_per_index",

tests/conformance/harness/expectations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ class PipelineUtilitiesExpected(_ForbidExtras):
141141
# - dict[recorder_name, list[record]] when multiple recorders (001).
142142
# - list[record] flat when a single recorder.
143143
trace_records: Any = None
144+
# Parallel-branches fixtures (032-038). On fail_fast,
145+
# ``recoverable_state`` carries the pre-entry parent state
146+
# snapshot per spec §11.5; the harness asserts it equals the
147+
# ``recoverable_state`` attached to the raised
148+
# ``ParallelBranchesBranchFailed``.
149+
recoverable_state: dict[str, Any] | None = None
144150

145151

146152
# ---------------------------------------------------------------------------
@@ -201,6 +207,7 @@ class ObservabilityExpected(_ForbidExtras):
201207
"timing_records",
202208
"trace_records",
203209
"expected_observer_event",
210+
"recoverable_state",
204211
}
205212
)
206213
_OBSERVABILITY_KEYS = frozenset(

0 commit comments

Comments
 (0)