Skip to content

Commit 923e01f

Browse files
refactor: make Diagram.cascade() a classmethod factory
The cascade preview pattern is now a single call: dj.Diagram.cascade(Session & 'subject_id=1').counts() cascade() constructs the Diagram directly from the table expression, includes all descendants (cross-schema), propagates restrictions, and trims to the affected subgraph. Table.delete() uses Diagram.cascade(self, ...) internally. Table.drop() expands descendants inline via nx.descendants(). Removes _from_table() — no longer needed. Also removes dry_run from delete() and drop() since safemode's transaction + rollback provides a safer preview, and Diagram.cascade().counts() provides programmatic preview. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3d239f1 commit 923e01f

File tree

4 files changed

+50
-60
lines changed

4 files changed

+50
-60
lines changed

src/datajoint/diagram.py

Lines changed: 40 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self, source, context=None) -> None:
100100
self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
101101
self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
102102
self._part_integrity = source._part_integrity
103+
self._source = getattr(source, "_source", None)
103104
super().__init__(source)
104105
return
105106

@@ -131,6 +132,7 @@ def __init__(self, source, context=None) -> None:
131132

132133
# Enumerate nodes from all the items in the list
133134
self.nodes_to_show = set()
135+
self._source = None
134136
try:
135137
self.nodes_to_show.add(source.full_table_name)
136138
except AttributeError:
@@ -165,39 +167,6 @@ def from_sequence(cls, sequence) -> "Diagram":
165167
"""
166168
return functools.reduce(lambda x, y: x + y, map(Diagram, sequence))
167169

168-
@classmethod
169-
def _from_table(cls, table_expr) -> "Diagram":
170-
"""
171-
Create a Diagram containing table_expr and all its descendants.
172-
173-
Internal factory for ``Table.delete()`` and ``Table.drop()``.
174-
Bypasses the normal ``__init__`` which does caller-frame introspection
175-
and source-type resolution.
176-
177-
Parameters
178-
----------
179-
table_expr : Table
180-
A table instance with ``connection`` and ``full_table_name``.
181-
182-
Returns
183-
-------
184-
Diagram
185-
"""
186-
conn = table_expr.connection
187-
conn.dependencies.load()
188-
descendants = set(conn.dependencies.descendants(table_expr.full_table_name))
189-
result = cls.__new__(cls)
190-
nx.DiGraph.__init__(result, conn.dependencies)
191-
result._connection = conn
192-
result.context = {}
193-
result.nodes_to_show = descendants
194-
result._expanded_nodes = set(descendants)
195-
result._cascade_restrictions = {}
196-
result._restrict_conditions = {}
197-
result._restriction_attrs = {}
198-
result._part_integrity = "enforce"
199-
return result
200-
201170
def add_parts(self) -> "Diagram":
202171
"""
203172
Add part tables of all masters already in the diagram.
@@ -347,45 +316,65 @@ def topo_sort(self) -> list[str]:
347316
"""
348317
return topo_sort(self)
349318

350-
def cascade(self, table_expr, part_integrity="enforce"):
319+
@classmethod
320+
def cascade(cls, table_expr, part_integrity="enforce"):
351321
"""
352-
Apply cascade restriction and propagate downstream.
353-
354-
OR at convergence — a child row is affected if *any* restricted
355-
ancestor taints it. Used for delete.
322+
Create a cascade diagram for a table expression.
356323
357-
Can only be called once on an unrestricted Diagram. Cannot be
358-
mixed with ``restrict()``.
324+
Builds a Diagram from the table's dependency graph, includes all
325+
descendants (across all loaded schemas), and propagates the
326+
restriction downstream using OR convergence — a child row is
327+
affected if *any* restricted ancestor taints it.
359328
360329
Parameters
361330
----------
362331
table_expr : QueryExpression
363-
A restricted table expression
332+
A (possibly restricted) table expression
364333
(e.g., ``Session & 'subject_id=1'``).
365334
part_integrity : str, optional
366335
``"enforce"`` (default), ``"ignore"``, or ``"cascade"``.
367336
368337
Returns
369338
-------
370339
Diagram
371-
New Diagram with cascade restrictions applied.
340+
New Diagram with cascade restrictions applied, trimmed to
341+
the seed table and its affected descendants.
342+
343+
Examples
344+
--------
345+
>>> # Preview cascade impact across all downstream schemas
346+
>>> dj.Diagram.cascade(Session & 'subject_id=1').counts()
347+
348+
>>> # Inspect the cascade subgraph
349+
>>> dj.Diagram.cascade(Session & 'subject_id=1')
372350
"""
373-
if self._cascade_restrictions or self._restrict_conditions:
374-
raise DataJointError(
375-
"cascade() can only be called once on an unrestricted Diagram. "
376-
"cascade and restrict modes are mutually exclusive."
377-
)
378-
result = Diagram(self)
379-
result._part_integrity = part_integrity
351+
conn = table_expr.connection
352+
conn.dependencies.load()
380353
node = table_expr.full_table_name
381-
if node not in result.nodes():
382-
raise DataJointError(f"Table {node} is not in the diagram.")
354+
355+
result = cls.__new__(cls)
356+
nx.DiGraph.__init__(result, conn.dependencies)
357+
result._connection = conn
358+
result.context = {}
359+
result._cascade_restrictions = {}
360+
result._restrict_conditions = {}
361+
result._restriction_attrs = {}
362+
result._part_integrity = part_integrity
363+
result._source = table_expr
364+
365+
# Include seed + all descendants
366+
descendants = set(nx.descendants(result, node)) | {node}
367+
result.nodes_to_show = descendants
368+
result._expanded_nodes = set(descendants)
369+
383370
# Seed restriction
384371
restriction = AndList(table_expr.restriction)
385372
result._cascade_restrictions[node] = [restriction] if restriction else []
386373
result._restriction_attrs[node] = set(table_expr.restriction_attributes)
374+
387375
# Propagate downstream
388376
result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)
377+
389378
# Trim graph to cascade subgraph: only restricted tables
390379
# (seed + descendants) plus alias nodes connecting them.
391380
keep = set(result._cascade_restrictions)

src/datajoint/table.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,7 @@ def delete(
996996
997997
To preview cascade impact without executing, use ``Diagram``::
998998
999-
diag = dj.Diagram(schema)
1000-
diag.cascade(MyTable & restriction).counts()
999+
dj.Diagram.cascade(MyTable & restriction).counts()
10011000
10021001
Args:
10031002
transaction: If `True`, use of the entire delete becomes an atomic transaction.
@@ -1022,8 +1021,7 @@ def delete(
10221021
raise ValueError(f"part_integrity must be 'enforce', 'ignore', or 'cascade', " f"got {part_integrity!r}")
10231022
from .diagram import Diagram
10241023

1025-
diagram = Diagram._from_table(self)
1026-
diagram = diagram.cascade(self, part_integrity=part_integrity)
1024+
diagram = Diagram.cascade(self, part_integrity=part_integrity)
10271025

10281026
conn = self.connection
10291027
prompt = conn._config["safemode"] if prompt is None else prompt
@@ -1169,9 +1167,14 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
11691167
raise DataJointError(
11701168
"A table with an applied restriction cannot be dropped. " "Call drop() on the unrestricted Table."
11711169
)
1170+
import networkx as nx
11721171
from .diagram import Diagram
11731172

1174-
diagram = Diagram._from_table(self)
1173+
diagram = Diagram(self)
1174+
# Expand to include all descendants (cross-schema)
1175+
descendants = set(nx.descendants(diagram, self.full_table_name)) | {self.full_table_name}
1176+
diagram.nodes_to_show = descendants
1177+
diagram._expanded_nodes = set(descendants)
11751178
conn = self.connection
11761179
prompt = conn._config["safemode"] if prompt is None else prompt
11771180

tests/integration/test_cascade_delete.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ class Child(dj.Manual):
217217
Child.insert1((2, 1, "C2-1"))
218218

219219
# Preview restricted cascade via Diagram
220-
diag = dj.Diagram._from_table(Parent & {"parent_id": 1})
221-
counts = diag.cascade(Parent & {"parent_id": 1}).counts()
220+
counts = dj.Diagram.cascade(Parent & {"parent_id": 1}).counts()
222221

223222
assert isinstance(counts, dict)
224223
assert counts[Parent.full_table_name] == 1

tests/integration/test_erd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ def test_prune_after_restrict(schema_simp_pop):
126126

127127
def test_prune_after_cascade(schema_simp_pop):
128128
"""Prune after cascade removes tables with zero matching rows."""
129-
diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
130-
cascaded = diag.cascade(A & "id_a=0")
129+
cascaded = dj.Diagram.cascade(A & "id_a=0")
131130
counts = cascaded.counts()
132131

133132
pruned = cascaded.prune()

0 commit comments

Comments
 (0)