@@ -100,6 +100,7 @@ def __init__(self, source, context=None) -> None:
100100 self ._restrict_conditions = copy_module .deepcopy (source ._restrict_conditions )
101101 self ._restriction_attrs = copy_module .deepcopy (source ._restriction_attrs )
102102 self ._part_integrity = source ._part_integrity
103+ self ._source = getattr (source , "_source" , None )
103104 super ().__init__ (source )
104105 return
105106
@@ -131,6 +132,7 @@ def __init__(self, source, context=None) -> None:
131132
132133 # Enumerate nodes from all the items in the list
133134 self .nodes_to_show = set ()
135+ self ._source = None
134136 try :
135137 self .nodes_to_show .add (source .full_table_name )
136138 except AttributeError :
@@ -165,39 +167,6 @@ def from_sequence(cls, sequence) -> "Diagram":
165167 """
166168 return functools .reduce (lambda x , y : x + y , map (Diagram , sequence ))
167169
168- @classmethod
169- def _from_table (cls , table_expr ) -> "Diagram" :
170- """
171- Create a Diagram containing table_expr and all its descendants.
172-
173- Internal factory for ``Table.delete()`` and ``Table.drop()``.
174- Bypasses the normal ``__init__`` which does caller-frame introspection
175- and source-type resolution.
176-
177- Parameters
178- ----------
179- table_expr : Table
180- A table instance with ``connection`` and ``full_table_name``.
181-
182- Returns
183- -------
184- Diagram
185- """
186- conn = table_expr .connection
187- conn .dependencies .load ()
188- descendants = set (conn .dependencies .descendants (table_expr .full_table_name ))
189- result = cls .__new__ (cls )
190- nx .DiGraph .__init__ (result , conn .dependencies )
191- result ._connection = conn
192- result .context = {}
193- result .nodes_to_show = descendants
194- result ._expanded_nodes = set (descendants )
195- result ._cascade_restrictions = {}
196- result ._restrict_conditions = {}
197- result ._restriction_attrs = {}
198- result ._part_integrity = "enforce"
199- return result
200-
201170 def add_parts (self ) -> "Diagram" :
202171 """
203172 Add part tables of all masters already in the diagram.
@@ -347,45 +316,65 @@ def topo_sort(self) -> list[str]:
347316 """
348317 return topo_sort (self )
349318
350- def cascade (self , table_expr , part_integrity = "enforce" ):
319+ @classmethod
320+ def cascade (cls , table_expr , part_integrity = "enforce" ):
351321 """
352- Apply cascade restriction and propagate downstream.
353-
354- OR at convergence — a child row is affected if *any* restricted
355- ancestor taints it. Used for delete.
322+ Create a cascade diagram for a table expression.
356323
357- Can only be called once on an unrestricted Diagram. Cannot be
358- mixed with ``restrict()``.
324+ Builds a Diagram from the table's dependency graph, includes all
325+ descendants (across all loaded schemas), and propagates the
326+ restriction downstream using OR convergence — a child row is
327+ affected if *any* restricted ancestor taints it.
359328
360329 Parameters
361330 ----------
362331 table_expr : QueryExpression
363- A restricted table expression
332+ A (possibly restricted) table expression
364333 (e.g., ``Session & 'subject_id=1'``).
365334 part_integrity : str, optional
366335 ``"enforce"`` (default), ``"ignore"``, or ``"cascade"``.
367336
368337 Returns
369338 -------
370339 Diagram
371- New Diagram with cascade restrictions applied.
340+ New Diagram with cascade restrictions applied, trimmed to
341+ the seed table and its affected descendants.
342+
343+ Examples
344+ --------
345+ >>> # Preview cascade impact across all downstream schemas
346+ >>> dj.Diagram.cascade(Session & 'subject_id=1').counts()
347+
348+ >>> # Inspect the cascade subgraph
349+ >>> dj.Diagram.cascade(Session & 'subject_id=1')
372350 """
373- if self ._cascade_restrictions or self ._restrict_conditions :
374- raise DataJointError (
375- "cascade() can only be called once on an unrestricted Diagram. "
376- "cascade and restrict modes are mutually exclusive."
377- )
378- result = Diagram (self )
379- result ._part_integrity = part_integrity
351+ conn = table_expr .connection
352+ conn .dependencies .load ()
380353 node = table_expr .full_table_name
381- if node not in result .nodes ():
382- raise DataJointError (f"Table { node } is not in the diagram." )
354+
355+ result = cls .__new__ (cls )
356+ nx .DiGraph .__init__ (result , conn .dependencies )
357+ result ._connection = conn
358+ result .context = {}
359+ result ._cascade_restrictions = {}
360+ result ._restrict_conditions = {}
361+ result ._restriction_attrs = {}
362+ result ._part_integrity = part_integrity
363+ result ._source = table_expr
364+
365+ # Include seed + all descendants
366+ descendants = set (nx .descendants (result , node )) | {node }
367+ result .nodes_to_show = descendants
368+ result ._expanded_nodes = set (descendants )
369+
383370 # Seed restriction
384371 restriction = AndList (table_expr .restriction )
385372 result ._cascade_restrictions [node ] = [restriction ] if restriction else []
386373 result ._restriction_attrs [node ] = set (table_expr .restriction_attributes )
374+
387375 # Propagate downstream
388376 result ._propagate_restrictions (node , mode = "cascade" , part_integrity = part_integrity )
377+
389378 # Trim graph to cascade subgraph: only restricted tables
390379 # (seed + descendants) plus alias nodes connecting them.
391380 keep = set (result ._cascade_restrictions )
0 commit comments