1010
1111Core concepts:
1212 - **Nodes**: Characteristics (PDF, CDF, etc.) with presence and definitiveness rules
13- - **Edges**: Unary computation methods between characteristics
13+ - **Edges**: Computation methods from one-or-many sources to one target
1414 - **Constraints**: Rules that determine when nodes/edges are applicable
1515 - **View**: A filtered subgraph for a specific distribution
1616 - **Definitive characteristics**: Starting points for computations
@@ -64,14 +64,14 @@ class CharacteristicRegistry:
6464 add_characteristic(name, is_definitive, presence_constraint=None, definitive_constraint=None)
6565 Declare a characteristic with presence and optional definitiveness rules.
6666 add_computation(method, label=DEFAULT_COMPUTATION_KEY, constraint=None)
67- Add a unary computation edge between declared nodes.
67+ Add a computation edge between declared nodes.
6868 view(distr)
6969 Create a filtered view for the given distribution.
7070
7171 Notes
7272 -----
7373 - Nodes must be declared before adding computations
74- - Only unary computations (1 source → 1 target) are supported
74+ - Only many-to-one computations (n sources → 1 target) are supported
7575 - No invariant validation happens during mutation; validation occurs when
7676 creating a view with view()
7777 """
@@ -88,10 +88,12 @@ def __init__(self) -> None:
8888 if getattr (self , "_initialized" , False ):
8989 return
9090
91- # Adjacency: src → dst → label → [EdgeMeta]
91+ # Adjacency projection: src → dst → label → [ComputationEdgeMeta]
92+ # For hyperedges (many sources -> one target), the same edge metadata object
93+ # is projected under each source to preserve graph reachability semantics.
9294 self ._adj : dict [
9395 GenericCharacteristicName ,
94- dict [GenericCharacteristicName , dict [LabelName , list [EdgeMeta ]]],
96+ dict [GenericCharacteristicName , dict [LabelName , list [ComputationEdgeMeta ]]],
9597 ] = {}
9698 self ._all_nodes : set [GenericCharacteristicName ] = set ()
9799
@@ -179,12 +181,12 @@ def add_computation(
179181 constraint : GraphPrimitiveConstraint | None = None ,
180182 ) -> None :
181183 """
182- Add a labeled unary computation edge.
184+ Add a labeled computation edge.
183185
184186 Parameters
185187 ----------
186188 method : ComputationMethod
187- Computation object with exactly one source and one target.
189+ Computation object with one-or-many sources and one target.
188190 label : LabelName, default=DEFAULT_COMPUTATION_KEY
189191 Variant label for the edge.
190192 constraint : GraphPrimitiveConstraint, optional
@@ -193,33 +195,36 @@ def add_computation(
193195 Raises
194196 ------
195197 ValueError
196- If method is not unary , or source/target nodes are not declared.
198+ If method has no sources , or source/target nodes are not declared.
197199
198200 Notes
199201 -----
200202 - Multiple edges with different labels can exist between the same nodes
201203 - The first matching edge for each label is kept when creating views
204+ - Hyperedges are represented as projected edges from each source to target,
205+ while preserving one shared underlying computation method.
202206 """
203- if len ( method .sources ) != 1 :
204- raise ValueError ("Only unary computations are supported (1 source → 1 target) ." )
207+ if not method .sources :
208+ raise ValueError ("Computation must define at least one source characteristic ." )
205209
206- src = method .sources [ 0 ]
210+ unique_sources = tuple ( dict . fromkeys ( method .sources ))
207211 dst = method .target
208212
209- if not self ._ensure_node (src ) or not self ._ensure_node (dst ):
213+ if not self ._ensure_node (dst ) or any ( not self ._ensure_node (src ) for src in unique_sources ):
210214 raise ValueError ("Source characteristic or destination characteristic is invalid." )
211215
212- self ._adj [src ].setdefault (dst , {})
216+ edge_meta = ComputationEdgeMeta (
217+ method = method ,
218+ constraint = constraint or GraphPrimitiveConstraint (),
219+ )
220+
213221 # TODO: We need to be careful here if some constraint more general and with the same label
214222 # than other it can consume it. Actually, the same label methods should not intersect their
215223 # constraints
216- self ._adj [src ][dst ].setdefault (label , [])
217- self ._adj [src ][dst ][label ].append (
218- ComputationEdgeMeta (
219- method = method ,
220- constraint = constraint or GraphPrimitiveConstraint (),
221- )
222- )
224+ for src in unique_sources :
225+ self ._adj [src ].setdefault (dst , {})
226+ self ._adj [src ][dst ].setdefault (label , [])
227+ self ._adj [src ][dst ][label ].append (edge_meta )
223228
224229 def add_characteristic (
225230 self ,
@@ -342,41 +347,36 @@ def view(self, distr: Distribution) -> RegistryView:
342347
343348 Notes
344349 -----
345- 1. Filters edges by their constraints
346- 2. Removes edges touching absent nodes
350+ 1. Computes present nodes for the distribution
351+ 2. Filters edges by node presence and edge constraints
347352 3. Adds analytical self-loops from distribution analytical computations
348353 4. Computes definitive nodes from the remaining present nodes
349354 5. Validates graph invariants
350355 """
351- # 1) Filter edges by applicability
356+ # 1) Compute present nodes once and pre-create adjacency.
357+ present_nodes = self ._compute_present_nodes (distr )
352358 adj : dict [
353359 GenericCharacteristicName , dict [GenericCharacteristicName , dict [LabelName , EdgeMeta ]]
354- ] = {}
355- for src , d in self ._adj .items ():
356- for dst , variants in d .items ():
360+ ] = {node : {} for node in present_nodes }
361+
362+ # 2) Filter edges by node presence and applicability.
363+ for src in present_nodes :
364+ for dst , variants in self ._adj .get (src , {}).items ():
365+ if dst not in present_nodes :
366+ continue
357367 kept : dict [LabelName , EdgeMeta ] = {}
358368 for label , metas in variants .items ():
359369 for edge in metas :
360- if edge .constraint .allows (distr ):
370+ if edge .constraint .allows (distr ) and all (
371+ source in present_nodes for source in edge .method .sources
372+ ):
361373 kept [label ] = edge
362374 # TODO: It is possible that there are two edges under the same label
363375 # that fit the same distribution, this should not be the case.
364376 # Taking the first one for now
365377 break
366378 if kept :
367- adj .setdefault (src , {}).setdefault (dst , {}).update (kept )
368-
369- # 2) Filter by node presence
370- present_nodes = self ._compute_present_nodes (distr )
371- if present_nodes :
372- adj = {
373- src : {dst : dict (variants ) for dst , variants in d .items () if dst in present_nodes }
374- for src , d in adj .items ()
375- if src in present_nodes
376- }
377- # Ensure isolated present nodes are preserved
378- for node in present_nodes :
379- adj .setdefault (node , {})
379+ adj [src ][dst ] = kept
380380
381381 # 3) Attach analytical loops
382382 self ._attach_analytical_loops (adj , distr , present_nodes )
@@ -438,6 +438,14 @@ def __init__(
438438 for src , d in adj .items ():
439439 self ._adj [src ] = {dst : dict (variants ) for dst , variants in d .items ()}
440440
441+ self ._rev_adj : dict [GenericCharacteristicName , set [GenericCharacteristicName ]] = {
442+ node : set () for node in self ._adj
443+ }
444+ for src , d in self ._adj .items ():
445+ for dst , variants in d .items ():
446+ if variants :
447+ self ._rev_adj .setdefault (dst , set ()).add (src )
448+
441449 self .definitive_characteristics : set [GenericCharacteristicName ] = set (definitive_nodes )
442450 self .all_characteristics : set [GenericCharacteristicName ] = set (present_nodes )
443451
@@ -504,11 +512,7 @@ def predecessors(self, v: GenericCharacteristicName) -> set[GenericCharacteristi
504512 set of str
505513 Characteristics that can reach v directly.
506514 """
507- res : set [GenericCharacteristicName ] = set ()
508- for src , d in self ._adj .items ():
509- if v in d and d [v ]:
510- res .add (src )
511- return res
515+ return set (self ._rev_adj .get (v , set ()))
512516
513517 def variants (
514518 self , src : GenericCharacteristicName , dst : GenericCharacteristicName
@@ -643,16 +647,8 @@ def _definitive_strongly_connected(self) -> bool:
643647 if fwd != (defs - {start }):
644648 return False
645649
646- # Check reverse reachability
647- seen : set [GenericCharacteristicName ] = {start }
648- stack = [start ]
649- while stack :
650- v = stack .pop ()
651- for w in self .predecessors (v ):
652- if w in defs and w not in seen :
653- seen .add (w )
654- stack .append (w )
655- return seen == defs
650+ rev = self ._reachable_from_many ({start }, allowed = defs , reverse = True )
651+ return rev == (defs - {start })
656652
657653 def _all_indefinitives_reachable_from_definitives (self ) -> bool :
658654 """
@@ -667,9 +663,7 @@ def _all_indefinitives_reachable_from_definitives(self) -> bool:
667663 if not indefs :
668664 return True
669665
670- total : set [GenericCharacteristicName ] = set ()
671- for d in self .definitive_characteristics :
672- total |= self ._reachable_from (d )
666+ total = self ._reachable_from_many (self .definitive_characteristics )
673667 return indefs .issubset (total )
674668
675669 def _exists_path_from_indefinitive_to_definitive (self ) -> bool :
@@ -682,46 +676,78 @@ def _exists_path_from_indefinitive_to_definitive(self) -> bool:
682676 True if such a path exists (which would violate invariants).
683677 """
684678 defs = self .definitive_characteristics
685- return any (self ._reachable_from (i ) & defs for i in self .indefinitive_characteristics )
679+ if not defs :
680+ return False
686681
687- def _reachable_from (
682+ can_reach_definitive = self ._reachable_from_many (defs , reverse = True )
683+ return bool (can_reach_definitive & self .indefinitive_characteristics )
684+
685+ def _reachable_from_many (
688686 self ,
689- src : GenericCharacteristicName ,
687+ sources : set [ GenericCharacteristicName ] ,
690688 * ,
691689 allowed : set [GenericCharacteristicName ] | None = None ,
690+ reverse : bool = False ,
692691 ) -> set [GenericCharacteristicName ]:
693692 """
694- Compute forward reachable nodes from src .
693+ Compute reachable nodes from multiple sources .
695694
696695 Parameters
697696 ----------
698- src : str
699- Starting node .
697+ sources : set of str
698+ Starting nodes .
700699 allowed : set of str, optional
701- Restrict to this set of nodes.
700+ Restrict traversal to this set of nodes.
701+ reverse : bool, default=False
702+ If True, traverse reverse edges.
702703
703704 Returns
704705 -------
705706 set of str
706- Nodes reachable from src (excluding src itself ).
707+ Nodes reachable from sources (excluding sources themselves ).
707708 """
708- if allowed is not None and src not in allowed :
709+ starts = {s for s in sources if allowed is None or s in allowed }
710+ if not starts :
709711 return set ()
710712
711713 visited : set [GenericCharacteristicName ] = set ()
712- stack = [ src ]
714+ stack = list ( starts )
713715 while stack :
714716 v = stack .pop ()
715717 if v in visited :
716718 continue
717719 visited .add (v )
718- for w in self .successors_nodes (v ):
720+ neighbors = self ._rev_adj .get (v , set ()) if reverse else self .successors_nodes (v )
721+ for w in neighbors :
719722 if allowed is not None and w not in allowed :
720723 continue
721724 if w not in visited :
722725 stack .append (w )
723- visited .discard (src )
724- return visited
726+
727+ return visited - starts
728+
729+ def _reachable_from (
730+ self ,
731+ src : GenericCharacteristicName ,
732+ * ,
733+ allowed : set [GenericCharacteristicName ] | None = None ,
734+ ) -> set [GenericCharacteristicName ]:
735+ """
736+ Compute forward reachable nodes from src.
737+
738+ Parameters
739+ ----------
740+ src : str
741+ Starting node.
742+ allowed : set of str, optional
743+ Restrict to this set of nodes.
744+
745+ Returns
746+ -------
747+ set of str
748+ Nodes reachable from src (excluding src itself).
749+ """
750+ return self ._reachable_from_many ({src }, allowed = allowed )
725751
726752 @staticmethod
727753 def _pick_method (
@@ -747,5 +773,5 @@ def _pick_method(
747773 return variants [prefer_label ].method
748774 if DEFAULT_COMPUTATION_KEY in variants :
749775 return variants [DEFAULT_COMPUTATION_KEY ].method
750- label = sorted (variants . keys ())[ 0 ]
776+ label = min (variants )
751777 return variants [label ].method
0 commit comments