1010
1111from __future__ import annotations
1212
13+ import asyncio
1314import copy
1415from collections .abc import Awaitable , Callable , Mapping , Sequence
1516from dataclasses import dataclass , field
1920
2021from openarmature .graph import (
2122 END ,
23+ BranchSpec ,
2224 CompiledGraph ,
2325 EndSentinel ,
2426 ExplicitMapping ,
2527 FanOutNode ,
2628 FieldNameMatching ,
2729 GraphBuilder ,
30+ ParallelBranchesNode ,
2831 ProjectionStrategy ,
2932 Reducer ,
3033 State ,
@@ -56,6 +59,14 @@ def _parse_type(s: str) -> Any:
5659 return float
5760 if s == "bool" :
5861 return bool
62+ # Unparameterized container types — parallel-branches fixtures
63+ # 034/035/037 use ``dict`` and ``list<dict>`` as state-field types
64+ # for accumulator slots (branch_errors, merged_dict, collected_labels)
65+ # where the element shape is heterogeneous across branches.
66+ if s == "dict" :
67+ return dict [str , Any ]
68+ if s == "list<dict>" :
69+ return list [dict [str , Any ]]
5970 if s .startswith ("list<" ) and s .endswith (">" ):
6071 return list [_parse_type (s [5 :- 1 ])]
6172 if s .startswith ("dict<" ) and s .endswith (">" ):
@@ -357,6 +368,23 @@ async def fn(_state: Any) -> Mapping[str, Any]:
357368 return fn
358369
359370
371+ def _wrap_with_sleep (
372+ fn : Callable [[Any ], Awaitable [Mapping [str , Any ]]],
373+ sleep_ms : int ,
374+ ) -> Callable [[Any ], Awaitable [Mapping [str , Any ]]]:
375+ # ``sleep_ms`` companion modifier on a NodeSpec — sleep that many
376+ # milliseconds before the wrapped body fires. Used by parallel-branches
377+ # fixtures 033 (slow third branch for fail-fast cancellation) and 037
378+ # (randomized completion timing to verify insertion-order determinism).
379+ delay = sleep_ms / 1000.0
380+
381+ async def fn_with_sleep (state : Any ) -> Mapping [str , Any ]:
382+ await asyncio .sleep (delay )
383+ return await fn (state )
384+
385+ return fn_with_sleep
386+
387+
360388@dataclass (frozen = True )
361389class _TracingFanOutNode (FanOutNode [State , State ]):
362390 """Conformance helper: a FanOutNode that appends its name to a shared
@@ -383,6 +411,24 @@ async def run_with_context(
383411 )
384412
385413
414+ @dataclass (frozen = True )
415+ class _TracingParallelBranchesNode (ParallelBranchesNode [State ]):
416+ """Conformance helper: a ParallelBranchesNode that appends its name
417+ to the shared trace list once when the engine runs it. The
418+ parallel-branches dispatcher itself counts as one engine step from
419+ the parent's POV per §11.6, mirroring the fan-out tracing wrapper."""
420+
421+ trace_list : list [str ] = field (default_factory = list [str ])
422+
423+ async def run_with_context (
424+ self ,
425+ state : State ,
426+ context : _InvocationContext ,
427+ ) -> Mapping [str , Any ]:
428+ self .trace_list .append (self .name )
429+ return await super ().run_with_context (state , context )
430+
431+
386432@dataclass (frozen = True )
387433class _TracingSubgraphNode (SubgraphNode [State , State ]):
388434 """Conformance helper: a SubgraphNode that appends its name to a shared
@@ -457,6 +503,7 @@ def build_graph(
457503 node_middleware : Mapping [str , Sequence [Any ]] | None = None ,
458504 graph_middleware : Sequence [Any ] | None = None ,
459505 fan_out_instance_middleware : Mapping [str , Sequence [Any ]] | None = None ,
506+ parallel_branches_branch_middleware : Mapping [str , Mapping [str , Sequence [Any ]]] | None = None ,
460507) -> BuiltGraph :
461508 """Translate a graph-shaped fixture block into a `BuiltGraph`.
462509
@@ -486,6 +533,7 @@ def build_graph(
486533 subgraphs = subgraphs or {}
487534 node_middleware = node_middleware or {}
488535 fan_out_instance_middleware = fan_out_instance_middleware or {}
536+ parallel_branches_branch_middleware = parallel_branches_branch_middleware or {}
489537
490538 for mw in graph_middleware or ():
491539 builder .add_middleware (mw )
@@ -505,7 +553,8 @@ def build_graph(
505553 trace_list = trace ,
506554 middleware = per_node_mw ,
507555 )
508- elif "fan_out" in node_spec :
556+ continue
557+ if "fan_out" in node_spec :
509558 _add_fan_out_node (
510559 builder ,
511560 node_name ,
@@ -514,55 +563,47 @@ def build_graph(
514563 trace ,
515564 instance_middleware = fan_out_instance_middleware .get (node_name , ()),
516565 )
517- elif "raises" in node_spec :
518- builder .add_node (
566+ continue
567+ if "parallel_branches" in node_spec :
568+ _add_parallel_branches_node (
569+ builder ,
519570 node_name ,
520- _make_raising_fn (node_name , node_spec ["raises" ], trace ),
521- middleware = per_node_mw ,
571+ node_spec ["parallel_branches" ],
572+ subgraphs ,
573+ trace ,
574+ branch_middleware = parallel_branches_branch_middleware .get (node_name , {}),
522575 )
576+ continue
577+
578+ body : Callable [[Any ], Awaitable [Mapping [str , Any ]]]
579+ if "raises" in node_spec :
580+ body = _make_raising_fn (node_name , node_spec ["raises" ], trace )
523581 elif "flaky" in node_spec :
524- builder .add_node (
525- node_name ,
526- _make_flaky_fn (node_name , node_spec ["flaky" ], trace ),
527- middleware = per_node_mw ,
528- )
582+ body = _make_flaky_fn (node_name , node_spec ["flaky" ], trace )
529583 elif "flaky_by_index" in node_spec :
530- builder .add_node (
531- node_name ,
532- _make_flaky_by_index_fn (node_name , node_spec ["flaky_by_index" ], trace ),
533- middleware = per_node_mw ,
534- )
584+ body = _make_flaky_by_index_fn (node_name , node_spec ["flaky_by_index" ], trace )
535585 elif "flaky_instance_only" in node_spec :
536- builder .add_node (
537- node_name ,
538- _make_flaky_instance_only_fn (node_name , node_spec ["flaky_instance_only" ], trace ),
539- middleware = per_node_mw ,
540- )
586+ body = _make_flaky_instance_only_fn (node_name , node_spec ["flaky_instance_only" ], trace )
541587 elif "update" in node_spec :
542- builder .add_node (
543- node_name ,
544- _make_update_fn (node_name , node_spec ["update" ], trace ),
545- middleware = per_node_mw ,
546- )
588+ body = _make_update_fn (node_name , node_spec ["update" ], trace )
547589 elif "update_pure" in node_spec :
548- builder .add_node (
549- node_name ,
550- _make_pure_update_fn (node_name , node_spec ["update_pure" ], trace ),
551- middleware = per_node_mw ,
552- )
590+ body = _make_pure_update_fn (node_name , node_spec ["update_pure" ], trace )
553591 elif "update_from_field" in node_spec :
554- builder .add_node (
555- node_name ,
556- _make_update_from_field_fn (node_name , node_spec ["update_from_field" ], trace ),
557- middleware = per_node_mw ,
558- )
592+ body = _make_update_from_field_fn (node_name , node_spec ["update_from_field" ], trace )
559593 else :
560594 raise ValueError (
561595 f"node { node_name !r} has no recognized directive "
562596 "(update / update_pure / update_from_field / raises / flaky / "
563- "flaky_by_index / flaky_instance_only / fan_out / subgraph)"
597+ "flaky_by_index / flaky_instance_only / fan_out / parallel_branches / "
598+ "subgraph)"
564599 )
565600
601+ sleep_ms = node_spec .get ("sleep_ms" )
602+ if sleep_ms is not None :
603+ body = _wrap_with_sleep (body , int (sleep_ms ))
604+
605+ builder .add_node (node_name , body , middleware = per_node_mw )
606+
566607 for edge_spec in spec .get ("edges" , []):
567608 source = edge_spec ["from" ]
568609 if "to" in edge_spec :
@@ -623,6 +664,8 @@ def _record_event(event: NodeEvent) -> dict[str, Any]:
623664 rec ["error" ] = event .error .category
624665 if event .fan_out_index is not None :
625666 rec ["fan_out_index" ] = event .fan_out_index
667+ if event .branch_name is not None :
668+ rec ["branch_name" ] = event .branch_name
626669 return rec
627670
628671
@@ -734,6 +777,57 @@ def _add_fan_out_node(
734777 )
735778
736779
780+ def _add_parallel_branches_node (
781+ builder : GraphBuilder [Any ],
782+ node_name : str ,
783+ cfg : Mapping [str , Any ],
784+ subgraphs : Mapping [str , CompiledGraph [State ]],
785+ trace : list [str ],
786+ * ,
787+ branch_middleware : Mapping [str , Sequence [Any ]],
788+ ) -> None :
789+ """Translate a fixture's ``parallel_branches:`` block into a
790+ ``builder.add_parallel_branches_node`` call.
791+
792+ Each branch's ``subgraph`` name resolves against the shared
793+ ``subgraphs`` registry (built from the fixture's top-level
794+ ``subgraphs:`` block). ``branch_middleware`` maps branch-name to a
795+ pre-translated middleware list; the test driver populates it from
796+ each branch's ``middleware:`` block.
797+ """
798+ branches_cfg = cast ("dict[str, dict[str, Any]]" , cfg ["branches" ])
799+ branches : dict [str , BranchSpec [Any ]] = {}
800+ for branch_name , branch_cfg in branches_cfg .items ():
801+ sub_compiled = subgraphs [branch_cfg ["subgraph" ]]
802+ branches [branch_name ] = BranchSpec (
803+ subgraph = sub_compiled ,
804+ inputs = dict (branch_cfg .get ("inputs" ) or {}),
805+ outputs = dict (branch_cfg .get ("outputs" ) or {}),
806+ middleware = tuple (branch_middleware .get (branch_name , ())),
807+ )
808+
809+ builder .add_parallel_branches_node (
810+ node_name ,
811+ branches = branches ,
812+ error_policy = cfg .get ("error_policy" , "fail_fast" ),
813+ errors_field = cfg .get ("errors_field" ),
814+ )
815+
816+ # Swap the registered node for a tracing variant so the
817+ # conformance trace records the dispatcher as one engine step. The
818+ # builder's validation has already run; we only replace the stored
819+ # Node instance.
820+ original = cast ("ParallelBranchesNode[State]" , builder ._nodes [node_name ])
821+ builder ._nodes [node_name ] = _TracingParallelBranchesNode (
822+ name = original .name ,
823+ branches = original .branches ,
824+ error_policy = original .error_policy ,
825+ errors_field = original .errors_field ,
826+ middleware = original .middleware ,
827+ trace_list = trace ,
828+ )
829+
830+
737831def _resolve_callable_int_resolver (cfg : Mapping [str , Any ]) -> Callable [[Any ], int ]:
738832 """Build a state-reader callable from a fixture's callable config.
739833
0 commit comments