@@ -97,7 +97,7 @@ def __init__(self, source, context=None) -> None:
9797 self ._cascade_restrictions = copy_module .deepcopy (source ._cascade_restrictions )
9898 self ._restrict_conditions = copy_module .deepcopy (source ._restrict_conditions )
9999 self ._restriction_attrs = copy_module .deepcopy (source ._restriction_attrs )
100- self ._part_integrity = source . _part_integrity
100+ self ._part_integrity = getattr ( source , " _part_integrity" , "enforce" )
101101 super ().__init__ (source )
102102 return
103103
@@ -125,7 +125,6 @@ def __init__(self, source, context=None) -> None:
125125 self ._cascade_restrictions = {}
126126 self ._restrict_conditions = {}
127127 self ._restriction_attrs = {}
128- self ._part_integrity = "enforce"
129128
130129 # Enumerate nodes from all the items in the list
131130 self .nodes_to_show = set ()
@@ -193,7 +192,6 @@ def _from_table(cls, table_expr) -> "Diagram":
193192 result ._cascade_restrictions = {}
194193 result ._restrict_conditions = {}
195194 result ._restriction_attrs = {}
196- result ._part_integrity = "enforce"
197195 return result
198196
199197 def add_parts (self ) -> "Diagram" :
@@ -384,6 +382,34 @@ def cascade(self, table_expr, part_integrity="enforce"):
384382 result ._propagate_restrictions (node , mode = "cascade" , part_integrity = part_integrity )
385383 return result
386384
385+ @staticmethod
386+ def _restrict_freetable (ft , restrictions , mode = "cascade" ):
387+ """
388+ Apply cascade/restrict restrictions to a FreeTable.
389+
390+ Uses ``restrict()`` to properly convert each restriction (AndList,
391+ QueryExpression, etc.) into SQL via ``make_condition``, rather than
392+ assigning raw objects to ``_restriction`` which would produce
393+ invalid SQL in ``where_clause``.
394+
395+ For cascade mode (delete), restrictions from different parent edges
396+ are OR-ed: a row is deleted if ANY of its FK references point to a
397+ deleted row.
398+
399+ For restrict mode (export), restrictions are AND-ed: a row is
400+ included only if ALL ancestor conditions are satisfied.
401+ """
402+ if not restrictions :
403+ return ft
404+ if mode == "cascade" :
405+ # OR semantics — passing a list to restrict() creates an OrList
406+ return ft .restrict (restrictions )
407+ else :
408+ # AND semantics — each restriction narrows further
409+ for r in restrictions :
410+ ft = ft .restrict (r )
411+ return ft
412+
387413 def restrict (self , table_expr ):
388414 """
389415 Apply restrict condition and propagate downstream.
@@ -450,10 +476,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
450476 parent_ft = FreeTable (self ._connection , node )
451477 restr = restrictions [node ]
452478 if restr :
453- if mode == "cascade" :
454- parent_ft .restrict_in_place (restr ) # list → OR
455- else :
456- parent_ft ._restriction = restr # AndList → AND
479+ parent_ft = self ._restrict_freetable (parent_ft , restr , mode = mode )
457480
458481 parent_attrs = self ._restriction_attrs .get (node , set ())
459482
@@ -511,7 +534,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
511534 child_ft = FreeTable (self ._connection , target )
512535 child_restr = restrictions .get (target , [])
513536 if child_restr :
514- child_ft . restrict_in_place ( child_restr )
537+ child_ft = self . _restrict_freetable ( child_ft , child_restr , mode = mode )
515538 master_ft = FreeTable (self ._connection , master_name )
516539 from .condition import make_condition
517540
@@ -583,7 +606,7 @@ def _apply_propagation_rule(
583606
584607 self ._restriction_attrs .setdefault (child_node , set ()).update (child_attrs )
585608
586- def delete (self , transaction = True , prompt = None ):
609+ def delete (self , transaction = True , prompt = None , dry_run = False ):
587610 """
588611 Execute cascading delete using cascade restrictions.
589612
@@ -593,14 +616,20 @@ def delete(self, transaction=True, prompt=None):
593616 Wrap in a transaction. Default True.
594617 prompt : bool or None, optional
595618 Show preview and ask confirmation. Default ``dj.config['safemode']``.
619+ dry_run : bool, optional
620+ If True, return affected row counts without deleting. Default False.
596621
597622 Returns
598623 -------
599- int
600- Number of rows deleted from the root table.
624+ int or dict[str, int]
625+ Number of rows deleted from the root table, or (if ``dry_run``)
626+ a mapping of full table name to affected row count.
601627 """
602628 from .table import FreeTable
603629
630+ if dry_run :
631+ return self .preview ()
632+
604633 prompt = self ._connection ._config ["safemode" ] if prompt is None else prompt
605634
606635 if not self ._cascade_restrictions :
@@ -616,9 +645,7 @@ def delete(self, transaction=True, prompt=None):
616645 if prompt :
617646 for t in tables :
618647 ft = FreeTable (conn , t )
619- restr = self ._cascade_restrictions [t ]
620- if restr :
621- ft .restrict_in_place (restr )
648+ ft = self ._restrict_freetable (ft , self ._cascade_restrictions [t ])
622649 logger .info ("{table} ({count} tuples)" .format (table = t , count = len (ft )))
623650
624651 # Start transaction
@@ -641,9 +668,7 @@ def delete(self, transaction=True, prompt=None):
641668 try :
642669 for table_name in reversed (tables ):
643670 ft = FreeTable (conn , table_name )
644- restr = self ._cascade_restrictions [table_name ]
645- if restr :
646- ft .restrict_in_place (restr )
671+ ft = self ._restrict_freetable (ft , self ._cascade_restrictions [table_name ])
647672 count = ft .delete_quick (get_count = True )
648673 if count > 0 :
649674 deleted_tables .add (table_name )
@@ -668,7 +693,7 @@ def delete(self, transaction=True, prompt=None):
668693
669694 # Post-check part_integrity="enforce": roll back if a part table
670695 # had rows deleted without its master also having rows deleted.
671- if self . _part_integrity == "enforce" and deleted_tables :
696+ if getattr ( self , " _part_integrity" , "enforce" ) == "enforce" and deleted_tables :
672697 for table_name in deleted_tables :
673698 master = extract_master (table_name )
674699 if master and master not in deleted_tables :
@@ -702,7 +727,7 @@ def delete(self, transaction=True, prompt=None):
702727 root_count = 0
703728 return root_count
704729
705- def drop (self , prompt = None , part_integrity = "enforce" ):
730+ def drop (self , prompt = None , part_integrity = "enforce" , dry_run = False ):
706731 """
707732 Drop all tables in the diagram in reverse topological order.
708733
@@ -712,6 +737,13 @@ def drop(self, prompt=None, part_integrity="enforce"):
712737 Show preview and ask confirmation. Default ``dj.config['safemode']``.
713738 part_integrity : str, optional
714739 ``"enforce"`` (default) or ``"ignore"``.
740+ dry_run : bool, optional
741+ If True, return row counts without dropping. Default False.
742+
743+ Returns
744+ -------
745+ dict[str, int] or None
746+ If ``dry_run``, mapping of full table name to row count.
715747 """
716748 from .table import FreeTable
717749
@@ -730,6 +762,14 @@ def drop(self, prompt=None, part_integrity="enforce"):
730762 )
731763 )
732764
765+ if dry_run :
766+ result = {}
767+ for t in tables :
768+ count = len (FreeTable (conn , t ))
769+ result [t ] = count
770+ logger .info ("{table} ({count} tuples)" .format (table = t , count = count ))
771+ return result
772+
733773 do_drop = True
734774 if prompt :
735775 for t in tables :
@@ -752,6 +792,7 @@ def preview(self):
752792 from .table import FreeTable
753793
754794 restrictions = self ._cascade_restrictions or self ._restrict_conditions
795+ mode = "cascade" if self ._cascade_restrictions else "restrict"
755796 if not restrictions :
756797 raise DataJointError ("No restrictions applied. " "Call cascade() or restrict() first." )
757798
@@ -760,12 +801,7 @@ def preview(self):
760801 if node .isdigit () or node not in restrictions :
761802 continue
762803 ft = FreeTable (self ._connection , node )
763- restr = restrictions [node ]
764- if restr :
765- if isinstance (restr , list ) and not isinstance (restr , AndList ):
766- ft .restrict_in_place (restr ) # cascade: list → OR
767- else :
768- ft ._restriction = restr # restrict: AndList → AND
804+ ft = self ._restrict_freetable (ft , restrictions [node ], mode = mode )
769805 result [node ] = len (ft )
770806
771807 for t , count in result .items ():
@@ -789,19 +825,15 @@ def prune(self):
789825
790826 result = Diagram (self )
791827 restrictions = result ._cascade_restrictions or result ._restrict_conditions
828+ mode = "cascade" if result ._cascade_restrictions else "restrict"
792829
793830 if restrictions :
794831 # Restricted: check row counts under restriction
795832 for node in list (restrictions ):
796833 if node .isdigit ():
797834 continue
798835 ft = FreeTable (self ._connection , node )
799- restr = restrictions [node ]
800- if restr :
801- if isinstance (restr , list ) and not isinstance (restr , AndList ):
802- ft .restrict_in_place (restr )
803- else :
804- ft ._restriction = restr
836+ ft = self ._restrict_freetable (ft , restrictions [node ], mode = mode )
805837 if len (ft ) == 0 :
806838 restrictions .pop (node )
807839 result ._restriction_attrs .pop (node , None )
0 commit comments