@@ -97,6 +97,7 @@ def __init__(self, source, context=None) -> None:
9797 self ._cascade_restrictions = copy_module .deepcopy (source ._cascade_restrictions )
9898 self ._restrict_conditions = copy_module .deepcopy (source ._restrict_conditions )
9999 self ._restriction_attrs = copy_module .deepcopy (source ._restriction_attrs )
100+ self ._part_integrity = getattr (source , "_part_integrity" , "enforce" )
100101 super ().__init__ (source )
101102 return
102103
@@ -369,6 +370,7 @@ def cascade(self, table_expr, part_integrity="enforce"):
369370 "cascade and restrict modes are mutually exclusive."
370371 )
371372 result = Diagram (self )
373+ result ._part_integrity = part_integrity
372374 node = table_expr .full_table_name
373375 if node not in result .nodes ():
374376 raise DataJointError (f"Table { node } is not in the diagram." )
@@ -380,6 +382,34 @@ def cascade(self, table_expr, part_integrity="enforce"):
380382 result ._propagate_restrictions (node , mode = "cascade" , part_integrity = part_integrity )
381383 return result
382384
385+ @staticmethod
386+ def _restrict_freetable (ft , restrictions , mode = "cascade" ):
387+ """
388+ Apply cascade/restrict restrictions to a FreeTable.
389+
390+ Uses ``restrict()`` to properly convert each restriction (AndList,
391+ QueryExpression, etc.) into SQL via ``make_condition``, rather than
392+ assigning raw objects to ``_restriction`` which would produce
393+ invalid SQL in ``where_clause``.
394+
395+ For cascade mode (delete), restrictions from different parent edges
396+ are OR-ed: a row is deleted if ANY of its FK references point to a
397+ deleted row.
398+
399+ For restrict mode (export), restrictions are AND-ed: a row is
400+ included only if ALL ancestor conditions are satisfied.
401+ """
402+ if not restrictions :
403+ return ft
404+ if mode == "cascade" :
405+ # OR semantics — passing a list to restrict() creates an OrList
406+ return ft .restrict (restrictions )
407+ else :
408+ # AND semantics — each restriction narrows further
409+ for r in restrictions :
410+ ft = ft .restrict (r )
411+ return ft
412+
383413 def restrict (self , table_expr ):
384414 """
385415 Apply restrict condition and propagate downstream.
@@ -445,11 +475,8 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
445475 # Build parent FreeTable with current restriction
446476 parent_ft = FreeTable (self ._connection , node )
447477 restr = restrictions [node ]
448- if mode == "cascade" and restr :
449- parent_ft ._restriction = restr # plain list → OR
450- elif mode == "restrict" :
451- parent_ft ._restriction = restr # AndList → AND
452- # else: cascade with empty list → unrestricted
478+ if restr :
479+ parent_ft = self ._restrict_freetable (parent_ft , restr , mode = mode )
453480
454481 parent_attrs = self ._restriction_attrs .get (node , set ())
455482
@@ -507,14 +534,14 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
507534 child_ft = FreeTable (self ._connection , target )
508535 child_restr = restrictions .get (target , [])
509536 if child_restr :
510- child_ft . _restriction = child_restr
537+ child_ft = self . _restrict_freetable ( child_ft , child_restr , mode = mode )
511538 master_ft = FreeTable (self ._connection , master_name )
512539 from .condition import make_condition
513540
514541 master_restr = make_condition (
515542 master_ft ,
516543 (master_ft .proj () & child_ft .proj ()).to_arrays (),
517- master_ft ._restriction_attributes ,
544+ master_ft .restriction_attributes ,
518545 )
519546 restrictions [master_name ] = [master_restr ]
520547 self ._restriction_attrs [master_name ] = set ()
@@ -579,7 +606,7 @@ def _apply_propagation_rule(
579606
580607 self ._restriction_attrs .setdefault (child_node , set ()).update (child_attrs )
581608
582- def delete (self , transaction = True , prompt = None ):
609+ def delete (self , transaction = True , prompt = None , dry_run = False ):
583610 """
584611 Execute cascading delete using cascade restrictions.
585612
@@ -589,14 +616,20 @@ def delete(self, transaction=True, prompt=None):
589616 Wrap in a transaction. Default True.
590617 prompt : bool or None, optional
591618 Show preview and ask confirmation. Default ``dj.config['safemode']``.
619+ dry_run : bool, optional
620+ If True, return affected row counts without deleting. Default False.
592621
593622 Returns
594623 -------
595- int
596- Number of rows deleted from the root table.
624+ int or dict[str, int]
625+ Number of rows deleted from the root table, or (if ``dry_run``)
626+ a mapping of full table name to affected row count.
597627 """
598628 from .table import FreeTable
599629
630+ if dry_run :
631+ return self .preview ()
632+
600633 prompt = self ._connection ._config ["safemode" ] if prompt is None else prompt
601634
602635 if not self ._cascade_restrictions :
@@ -606,14 +639,15 @@ def delete(self, transaction=True, prompt=None):
606639
607640 # Pre-check part_integrity="enforce": ensure no part is deleted
608641 # before its master
609- for node in self ._cascade_restrictions :
610- master = extract_master (node )
611- if master and master not in self ._cascade_restrictions :
612- raise DataJointError (
613- f"Attempt to delete part table { node } before "
614- f"its master { master } . Delete from the master first, "
615- f"or use part_integrity='ignore' or 'cascade'."
616- )
642+ if getattr (self , "_part_integrity" , "enforce" ) == "enforce" :
643+ for node in self ._cascade_restrictions :
644+ master = extract_master (node )
645+ if master and master not in self ._cascade_restrictions :
646+ raise DataJointError (
647+ f"Attempt to delete part table { node } before "
648+ f"its master { master } . Delete from the master first, "
649+ f"or use part_integrity='ignore' or 'cascade'."
650+ )
617651
618652 # Get non-alias nodes with restrictions in topological order
619653 all_sorted = topo_sort (self )
@@ -623,9 +657,7 @@ def delete(self, transaction=True, prompt=None):
623657 if prompt :
624658 for t in tables :
625659 ft = FreeTable (conn , t )
626- restr = self ._cascade_restrictions [t ]
627- if restr :
628- ft ._restriction = restr
660+ ft = self ._restrict_freetable (ft , self ._cascade_restrictions [t ])
629661 logger .info ("{table} ({count} tuples)" .format (table = t , count = len (ft )))
630662
631663 # Start transaction
@@ -647,9 +679,7 @@ def delete(self, transaction=True, prompt=None):
647679 try :
648680 for table_name in reversed (tables ):
649681 ft = FreeTable (conn , table_name )
650- restr = self ._cascade_restrictions [table_name ]
651- if restr :
652- ft ._restriction = restr
682+ ft = self ._restrict_freetable (ft , self ._cascade_restrictions [table_name ])
653683 count = ft .delete_quick (get_count = True )
654684 logger .info ("Deleting {count} rows from {table}" .format (count = count , table = table_name ))
655685 if table_name == tables [0 ]:
@@ -692,7 +722,7 @@ def delete(self, transaction=True, prompt=None):
692722 root_count = 0
693723 return root_count
694724
695- def drop (self , prompt = None , part_integrity = "enforce" ):
725+ def drop (self , prompt = None , part_integrity = "enforce" , dry_run = False ):
696726 """
697727 Drop all tables in the diagram in reverse topological order.
698728
@@ -702,6 +732,13 @@ def drop(self, prompt=None, part_integrity="enforce"):
702732 Show preview and ask confirmation. Default ``dj.config['safemode']``.
703733 part_integrity : str, optional
704734 ``"enforce"`` (default) or ``"ignore"``.
735+ dry_run : bool, optional
736+ If True, return row counts without dropping. Default False.
737+
738+ Returns
739+ -------
740+ dict[str, int] or None
741+ If ``dry_run``, mapping of full table name to row count.
705742 """
706743 from .table import FreeTable
707744
@@ -720,6 +757,14 @@ def drop(self, prompt=None, part_integrity="enforce"):
720757 )
721758 )
722759
760+ if dry_run :
761+ result = {}
762+ for t in tables :
763+ count = len (FreeTable (conn , t ))
764+ result [t ] = count
765+ logger .info ("{table} ({count} tuples)" .format (table = t , count = count ))
766+ return result
767+
723768 do_drop = True
724769 if prompt :
725770 for t in tables :
@@ -742,6 +787,7 @@ def preview(self):
742787 from .table import FreeTable
743788
744789 restrictions = self ._cascade_restrictions or self ._restrict_conditions
790+ mode = "cascade" if self ._cascade_restrictions else "restrict"
745791 if not restrictions :
746792 raise DataJointError ("No restrictions applied. " "Call cascade() or restrict() first." )
747793
@@ -750,9 +796,7 @@ def preview(self):
750796 if node .isdigit () or node not in restrictions :
751797 continue
752798 ft = FreeTable (self ._connection , node )
753- restr = restrictions [node ]
754- if restr :
755- ft ._restriction = restr
799+ ft = self ._restrict_freetable (ft , restrictions [node ], mode = mode )
756800 result [node ] = len (ft )
757801
758802 for t , count in result .items ():
0 commit comments