Skip to content

Commit ca63d86

Browse files
committed
feat(registry): added support for hyper edges in the graph
1 parent 93df880 commit ca63d86

2 files changed

Lines changed: 308 additions & 77 deletions

File tree

src/pysatl_core/distributions/registry/graph.py

Lines changed: 98 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
Core concepts:
1212
- **Nodes**: Characteristics (PDF, CDF, etc.) with presence and definitiveness rules
13-
- **Edges**: Unary computation methods between characteristics
13+
- **Edges**: Computation methods from one-or-many sources to one target
1414
- **Constraints**: Rules that determine when nodes/edges are applicable
1515
- **View**: A filtered subgraph for a specific distribution
1616
- **Definitive characteristics**: Starting points for computations
@@ -64,14 +64,14 @@ class CharacteristicRegistry:
6464
add_characteristic(name, is_definitive, presence_constraint=None, definitive_constraint=None)
6565
Declare a characteristic with presence and optional definitiveness rules.
6666
add_computation(method, label=DEFAULT_COMPUTATION_KEY, constraint=None)
67-
Add a unary computation edge between declared nodes.
67+
Add a computation edge between declared nodes.
6868
view(distr)
6969
Create a filtered view for the given distribution.
7070
7171
Notes
7272
-----
7373
- Nodes must be declared before adding computations
74-
- Only unary computations (1 source → 1 target) are supported
74+
- Only many-to-one computations (n sources → 1 target) are supported
7575
- No invariant validation happens during mutation; validation occurs when
7676
creating a view with view()
7777
"""
@@ -88,10 +88,12 @@ def __init__(self) -> None:
8888
if getattr(self, "_initialized", False):
8989
return
9090

91-
# Adjacency: src → dst → label → [EdgeMeta]
91+
# Adjacency projection: src → dst → label → [ComputationEdgeMeta]
92+
# For hyperedges (many sources -> one target), the same edge metadata object
93+
# is projected under each source to preserve graph reachability semantics.
9294
self._adj: dict[
9395
GenericCharacteristicName,
94-
dict[GenericCharacteristicName, dict[LabelName, list[EdgeMeta]]],
96+
dict[GenericCharacteristicName, dict[LabelName, list[ComputationEdgeMeta]]],
9597
] = {}
9698
self._all_nodes: set[GenericCharacteristicName] = set()
9799

@@ -179,12 +181,12 @@ def add_computation(
179181
constraint: GraphPrimitiveConstraint | None = None,
180182
) -> None:
181183
"""
182-
Add a labeled unary computation edge.
184+
Add a labeled computation edge.
183185
184186
Parameters
185187
----------
186188
method : ComputationMethod
187-
Computation object with exactly one source and one target.
189+
Computation object with one-or-many sources and one target.
188190
label : LabelName, default=DEFAULT_COMPUTATION_KEY
189191
Variant label for the edge.
190192
constraint : GraphPrimitiveConstraint, optional
@@ -193,33 +195,36 @@ def add_computation(
193195
Raises
194196
------
195197
ValueError
196-
If method is not unary, or source/target nodes are not declared.
198+
If method has no sources, or source/target nodes are not declared.
197199
198200
Notes
199201
-----
200202
- Multiple edges with different labels can exist between the same nodes
201203
- The first matching edge for each label is kept when creating views
204+
- Hyperedges are represented as projected edges from each source to target,
205+
while preserving one shared underlying computation method.
202206
"""
203-
if len(method.sources) != 1:
204-
raise ValueError("Only unary computations are supported (1 source → 1 target).")
207+
if not method.sources:
208+
raise ValueError("Computation must define at least one source characteristic.")
205209

206-
src = method.sources[0]
210+
unique_sources = tuple(dict.fromkeys(method.sources))
207211
dst = method.target
208212

209-
if not self._ensure_node(src) or not self._ensure_node(dst):
213+
if not self._ensure_node(dst) or any(not self._ensure_node(src) for src in unique_sources):
210214
raise ValueError("Source characteristic or destination characteristic is invalid.")
211215

212-
self._adj[src].setdefault(dst, {})
216+
edge_meta = ComputationEdgeMeta(
217+
method=method,
218+
constraint=constraint or GraphPrimitiveConstraint(),
219+
)
220+
213221
# TODO: We need to be careful here if some constraint more general and with the same label
214222
# than other it can consume it. Actually, the same label methods should not intersect their
215223
# constraints
216-
self._adj[src][dst].setdefault(label, [])
217-
self._adj[src][dst][label].append(
218-
ComputationEdgeMeta(
219-
method=method,
220-
constraint=constraint or GraphPrimitiveConstraint(),
221-
)
222-
)
224+
for src in unique_sources:
225+
self._adj[src].setdefault(dst, {})
226+
self._adj[src][dst].setdefault(label, [])
227+
self._adj[src][dst][label].append(edge_meta)
223228

224229
def add_characteristic(
225230
self,
@@ -342,41 +347,36 @@ def view(self, distr: Distribution) -> RegistryView:
342347
343348
Notes
344349
-----
345-
1. Filters edges by their constraints
346-
2. Removes edges touching absent nodes
350+
1. Computes present nodes for the distribution
351+
2. Filters edges by node presence and edge constraints
347352
3. Adds analytical self-loops from distribution analytical computations
348353
4. Computes definitive nodes from the remaining present nodes
349354
5. Validates graph invariants
350355
"""
351-
# 1) Filter edges by applicability
356+
# 1) Compute present nodes once and pre-create adjacency.
357+
present_nodes = self._compute_present_nodes(distr)
352358
adj: dict[
353359
GenericCharacteristicName, dict[GenericCharacteristicName, dict[LabelName, EdgeMeta]]
354-
] = {}
355-
for src, d in self._adj.items():
356-
for dst, variants in d.items():
360+
] = {node: {} for node in present_nodes}
361+
362+
# 2) Filter edges by node presence and applicability.
363+
for src in present_nodes:
364+
for dst, variants in self._adj.get(src, {}).items():
365+
if dst not in present_nodes:
366+
continue
357367
kept: dict[LabelName, EdgeMeta] = {}
358368
for label, metas in variants.items():
359369
for edge in metas:
360-
if edge.constraint.allows(distr):
370+
if edge.constraint.allows(distr) and all(
371+
source in present_nodes for source in edge.method.sources
372+
):
361373
kept[label] = edge
362374
# TODO: It is possible that there are two edges under the same label
363375
# that fit the same distribution, this should not be the case.
364376
# Taking the first one for now
365377
break
366378
if kept:
367-
adj.setdefault(src, {}).setdefault(dst, {}).update(kept)
368-
369-
# 2) Filter by node presence
370-
present_nodes = self._compute_present_nodes(distr)
371-
if present_nodes:
372-
adj = {
373-
src: {dst: dict(variants) for dst, variants in d.items() if dst in present_nodes}
374-
for src, d in adj.items()
375-
if src in present_nodes
376-
}
377-
# Ensure isolated present nodes are preserved
378-
for node in present_nodes:
379-
adj.setdefault(node, {})
379+
adj[src][dst] = kept
380380

381381
# 3) Attach analytical loops
382382
self._attach_analytical_loops(adj, distr, present_nodes)
@@ -438,6 +438,14 @@ def __init__(
438438
for src, d in adj.items():
439439
self._adj[src] = {dst: dict(variants) for dst, variants in d.items()}
440440

441+
self._rev_adj: dict[GenericCharacteristicName, set[GenericCharacteristicName]] = {
442+
node: set() for node in self._adj
443+
}
444+
for src, d in self._adj.items():
445+
for dst, variants in d.items():
446+
if variants:
447+
self._rev_adj.setdefault(dst, set()).add(src)
448+
441449
self.definitive_characteristics: set[GenericCharacteristicName] = set(definitive_nodes)
442450
self.all_characteristics: set[GenericCharacteristicName] = set(present_nodes)
443451

@@ -504,11 +512,7 @@ def predecessors(self, v: GenericCharacteristicName) -> set[GenericCharacteristi
504512
set of str
505513
Characteristics that can reach v directly.
506514
"""
507-
res: set[GenericCharacteristicName] = set()
508-
for src, d in self._adj.items():
509-
if v in d and d[v]:
510-
res.add(src)
511-
return res
515+
return set(self._rev_adj.get(v, set()))
512516

513517
def variants(
514518
self, src: GenericCharacteristicName, dst: GenericCharacteristicName
@@ -643,16 +647,8 @@ def _definitive_strongly_connected(self) -> bool:
643647
if fwd != (defs - {start}):
644648
return False
645649

646-
# Check reverse reachability
647-
seen: set[GenericCharacteristicName] = {start}
648-
stack = [start]
649-
while stack:
650-
v = stack.pop()
651-
for w in self.predecessors(v):
652-
if w in defs and w not in seen:
653-
seen.add(w)
654-
stack.append(w)
655-
return seen == defs
650+
rev = self._reachable_from_many({start}, allowed=defs, reverse=True)
651+
return rev == (defs - {start})
656652

657653
def _all_indefinitives_reachable_from_definitives(self) -> bool:
658654
"""
@@ -667,9 +663,7 @@ def _all_indefinitives_reachable_from_definitives(self) -> bool:
667663
if not indefs:
668664
return True
669665

670-
total: set[GenericCharacteristicName] = set()
671-
for d in self.definitive_characteristics:
672-
total |= self._reachable_from(d)
666+
total = self._reachable_from_many(self.definitive_characteristics)
673667
return indefs.issubset(total)
674668

675669
def _exists_path_from_indefinitive_to_definitive(self) -> bool:
@@ -682,46 +676,78 @@ def _exists_path_from_indefinitive_to_definitive(self) -> bool:
682676
True if such a path exists (which would violate invariants).
683677
"""
684678
defs = self.definitive_characteristics
685-
return any(self._reachable_from(i) & defs for i in self.indefinitive_characteristics)
679+
if not defs:
680+
return False
686681

687-
def _reachable_from(
682+
can_reach_definitive = self._reachable_from_many(defs, reverse=True)
683+
return bool(can_reach_definitive & self.indefinitive_characteristics)
684+
685+
def _reachable_from_many(
688686
self,
689-
src: GenericCharacteristicName,
687+
sources: set[GenericCharacteristicName],
690688
*,
691689
allowed: set[GenericCharacteristicName] | None = None,
690+
reverse: bool = False,
692691
) -> set[GenericCharacteristicName]:
693692
"""
694-
Compute forward reachable nodes from src.
693+
Compute reachable nodes from multiple sources.
695694
696695
Parameters
697696
----------
698-
src : str
699-
Starting node.
697+
sources : set of str
698+
Starting nodes.
700699
allowed : set of str, optional
701-
Restrict to this set of nodes.
700+
Restrict traversal to this set of nodes.
701+
reverse : bool, default=False
702+
If True, traverse reverse edges.
702703
703704
Returns
704705
-------
705706
set of str
706-
Nodes reachable from src (excluding src itself).
707+
Nodes reachable from sources (excluding sources themselves).
707708
"""
708-
if allowed is not None and src not in allowed:
709+
starts = {s for s in sources if allowed is None or s in allowed}
710+
if not starts:
709711
return set()
710712

711713
visited: set[GenericCharacteristicName] = set()
712-
stack = [src]
714+
stack = list(starts)
713715
while stack:
714716
v = stack.pop()
715717
if v in visited:
716718
continue
717719
visited.add(v)
718-
for w in self.successors_nodes(v):
720+
neighbors = self._rev_adj.get(v, set()) if reverse else self.successors_nodes(v)
721+
for w in neighbors:
719722
if allowed is not None and w not in allowed:
720723
continue
721724
if w not in visited:
722725
stack.append(w)
723-
visited.discard(src)
724-
return visited
726+
727+
return visited - starts
728+
729+
def _reachable_from(
730+
self,
731+
src: GenericCharacteristicName,
732+
*,
733+
allowed: set[GenericCharacteristicName] | None = None,
734+
) -> set[GenericCharacteristicName]:
735+
"""
736+
Compute forward reachable nodes from src.
737+
738+
Parameters
739+
----------
740+
src : str
741+
Starting node.
742+
allowed : set of str, optional
743+
Restrict to this set of nodes.
744+
745+
Returns
746+
-------
747+
set of str
748+
Nodes reachable from src (excluding src itself).
749+
"""
750+
return self._reachable_from_many({src}, allowed=allowed)
725751

726752
@staticmethod
727753
def _pick_method(
@@ -747,5 +773,5 @@ def _pick_method(
747773
return variants[prefer_label].method
748774
if DEFAULT_COMPUTATION_KEY in variants:
749775
return variants[DEFAULT_COMPUTATION_KEY].method
750-
label = sorted(variants.keys())[0]
776+
label = min(variants)
751777
return variants[label].method

0 commit comments

Comments
 (0)