Skip to content

Commit 204745a

Browse files
fix: cascade delete with proper SQL generation, OR convergence, and post-check part integrity
Replace direct `_restriction` assignment with `restrict()` calls in Diagram so that AndList and QueryExpression objects are converted to valid SQL via `make_condition()`. Cascade delete uses OR convergence (a row is deleted if ANY FK reference points to a deleted row), while restrict/export uses AND. Part integrity enforcement uses a data-driven post-check: only raises when rows were actually deleted from a Part without its master also being deleted. This avoids false positives when a Part table appears in the cascade graph but has zero affected rows. Also adds dry_run support to delete()/drop(), prune() method, fixes CLI test subprocess invocation, and updates test fixtures. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b8fd688 commit 204745a

File tree

6 files changed

+1039
-244
lines changed

6 files changed

+1039
-244
lines changed

pixi.lock

Lines changed: 869 additions & 203 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/datajoint/diagram.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, source, context=None) -> None:
9797
self._cascade_restrictions = copy_module.deepcopy(source._cascade_restrictions)
9898
self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
9999
self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
100-
self._part_integrity = source._part_integrity
100+
self._part_integrity = getattr(source, "_part_integrity", "enforce")
101101
super().__init__(source)
102102
return
103103

@@ -125,7 +125,6 @@ def __init__(self, source, context=None) -> None:
125125
self._cascade_restrictions = {}
126126
self._restrict_conditions = {}
127127
self._restriction_attrs = {}
128-
self._part_integrity = "enforce"
129128

130129
# Enumerate nodes from all the items in the list
131130
self.nodes_to_show = set()
@@ -193,7 +192,6 @@ def _from_table(cls, table_expr) -> "Diagram":
193192
result._cascade_restrictions = {}
194193
result._restrict_conditions = {}
195194
result._restriction_attrs = {}
196-
result._part_integrity = "enforce"
197195
return result
198196

199197
def add_parts(self) -> "Diagram":
@@ -384,6 +382,34 @@ def cascade(self, table_expr, part_integrity="enforce"):
384382
result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)
385383
return result
386384

385+
@staticmethod
386+
def _restrict_freetable(ft, restrictions, mode="cascade"):
387+
"""
388+
Apply cascade/restrict restrictions to a FreeTable.
389+
390+
Uses ``restrict()`` to properly convert each restriction (AndList,
391+
QueryExpression, etc.) into SQL via ``make_condition``, rather than
392+
assigning raw objects to ``_restriction`` which would produce
393+
invalid SQL in ``where_clause``.
394+
395+
For cascade mode (delete), restrictions from different parent edges
396+
are OR-ed: a row is deleted if ANY of its FK references point to a
397+
deleted row.
398+
399+
For restrict mode (export), restrictions are AND-ed: a row is
400+
included only if ALL ancestor conditions are satisfied.
401+
"""
402+
if not restrictions:
403+
return ft
404+
if mode == "cascade":
405+
# OR semantics — passing a list to restrict() creates an OrList
406+
return ft.restrict(restrictions)
407+
else:
408+
# AND semantics — each restriction narrows further
409+
for r in restrictions:
410+
ft = ft.restrict(r)
411+
return ft
412+
387413
def restrict(self, table_expr):
388414
"""
389415
Apply restrict condition and propagate downstream.
@@ -450,10 +476,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
450476
parent_ft = FreeTable(self._connection, node)
451477
restr = restrictions[node]
452478
if restr:
453-
if mode == "cascade":
454-
parent_ft.restrict_in_place(restr) # list → OR
455-
else:
456-
parent_ft._restriction = restr # AndList → AND
479+
parent_ft = self._restrict_freetable(parent_ft, restr, mode=mode)
457480

458481
parent_attrs = self._restriction_attrs.get(node, set())
459482

@@ -511,7 +534,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
511534
child_ft = FreeTable(self._connection, target)
512535
child_restr = restrictions.get(target, [])
513536
if child_restr:
514-
child_ft.restrict_in_place(child_restr)
537+
child_ft = self._restrict_freetable(child_ft, child_restr, mode=mode)
515538
master_ft = FreeTable(self._connection, master_name)
516539
from .condition import make_condition
517540

@@ -583,7 +606,7 @@ def _apply_propagation_rule(
583606

584607
self._restriction_attrs.setdefault(child_node, set()).update(child_attrs)
585608

586-
def delete(self, transaction=True, prompt=None):
609+
def delete(self, transaction=True, prompt=None, dry_run=False):
587610
"""
588611
Execute cascading delete using cascade restrictions.
589612
@@ -593,14 +616,20 @@ def delete(self, transaction=True, prompt=None):
593616
Wrap in a transaction. Default True.
594617
prompt : bool or None, optional
595618
Show preview and ask confirmation. Default ``dj.config['safemode']``.
619+
dry_run : bool, optional
620+
If True, return affected row counts without deleting. Default False.
596621
597622
Returns
598623
-------
599-
int
600-
Number of rows deleted from the root table.
624+
int or dict[str, int]
625+
Number of rows deleted from the root table, or (if ``dry_run``)
626+
a mapping of full table name to affected row count.
601627
"""
602628
from .table import FreeTable
603629

630+
if dry_run:
631+
return self.preview()
632+
604633
prompt = self._connection._config["safemode"] if prompt is None else prompt
605634

606635
if not self._cascade_restrictions:
@@ -616,9 +645,7 @@ def delete(self, transaction=True, prompt=None):
616645
if prompt:
617646
for t in tables:
618647
ft = FreeTable(conn, t)
619-
restr = self._cascade_restrictions[t]
620-
if restr:
621-
ft.restrict_in_place(restr)
648+
ft = self._restrict_freetable(ft, self._cascade_restrictions[t])
622649
logger.info("{table} ({count} tuples)".format(table=t, count=len(ft)))
623650

624651
# Start transaction
@@ -641,9 +668,7 @@ def delete(self, transaction=True, prompt=None):
641668
try:
642669
for table_name in reversed(tables):
643670
ft = FreeTable(conn, table_name)
644-
restr = self._cascade_restrictions[table_name]
645-
if restr:
646-
ft.restrict_in_place(restr)
671+
ft = self._restrict_freetable(ft, self._cascade_restrictions[table_name])
647672
count = ft.delete_quick(get_count=True)
648673
if count > 0:
649674
deleted_tables.add(table_name)
@@ -668,7 +693,7 @@ def delete(self, transaction=True, prompt=None):
668693

669694
# Post-check part_integrity="enforce": roll back if a part table
670695
# had rows deleted without its master also having rows deleted.
671-
if self._part_integrity == "enforce" and deleted_tables:
696+
if getattr(self, "_part_integrity", "enforce") == "enforce" and deleted_tables:
672697
for table_name in deleted_tables:
673698
master = extract_master(table_name)
674699
if master and master not in deleted_tables:
@@ -702,7 +727,7 @@ def delete(self, transaction=True, prompt=None):
702727
root_count = 0
703728
return root_count
704729

705-
def drop(self, prompt=None, part_integrity="enforce"):
730+
def drop(self, prompt=None, part_integrity="enforce", dry_run=False):
706731
"""
707732
Drop all tables in the diagram in reverse topological order.
708733
@@ -712,6 +737,13 @@ def drop(self, prompt=None, part_integrity="enforce"):
712737
Show preview and ask confirmation. Default ``dj.config['safemode']``.
713738
part_integrity : str, optional
714739
``"enforce"`` (default) or ``"ignore"``.
740+
dry_run : bool, optional
741+
If True, return row counts without dropping. Default False.
742+
743+
Returns
744+
-------
745+
dict[str, int] or None
746+
If ``dry_run``, mapping of full table name to row count.
715747
"""
716748
from .table import FreeTable
717749

@@ -730,6 +762,14 @@ def drop(self, prompt=None, part_integrity="enforce"):
730762
)
731763
)
732764

765+
if dry_run:
766+
result = {}
767+
for t in tables:
768+
count = len(FreeTable(conn, t))
769+
result[t] = count
770+
logger.info("{table} ({count} tuples)".format(table=t, count=count))
771+
return result
772+
733773
do_drop = True
734774
if prompt:
735775
for t in tables:
@@ -752,6 +792,7 @@ def preview(self):
752792
from .table import FreeTable
753793

754794
restrictions = self._cascade_restrictions or self._restrict_conditions
795+
mode = "cascade" if self._cascade_restrictions else "restrict"
755796
if not restrictions:
756797
raise DataJointError("No restrictions applied. " "Call cascade() or restrict() first.")
757798

@@ -760,12 +801,7 @@ def preview(self):
760801
if node.isdigit() or node not in restrictions:
761802
continue
762803
ft = FreeTable(self._connection, node)
763-
restr = restrictions[node]
764-
if restr:
765-
if isinstance(restr, list) and not isinstance(restr, AndList):
766-
ft.restrict_in_place(restr) # cascade: list → OR
767-
else:
768-
ft._restriction = restr # restrict: AndList → AND
804+
ft = self._restrict_freetable(ft, restrictions[node], mode=mode)
769805
result[node] = len(ft)
770806

771807
for t, count in result.items():
@@ -789,19 +825,15 @@ def prune(self):
789825

790826
result = Diagram(self)
791827
restrictions = result._cascade_restrictions or result._restrict_conditions
828+
mode = "cascade" if result._cascade_restrictions else "restrict"
792829

793830
if restrictions:
794831
# Restricted: check row counts under restriction
795832
for node in list(restrictions):
796833
if node.isdigit():
797834
continue
798835
ft = FreeTable(self._connection, node)
799-
restr = restrictions[node]
800-
if restr:
801-
if isinstance(restr, list) and not isinstance(restr, AndList):
802-
ft.restrict_in_place(restr)
803-
else:
804-
ft._restriction = restr
836+
ft = self._restrict_freetable(ft, restrictions[node], mode=mode)
805837
if len(ft) == 0:
806838
restrictions.pop(node)
807839
result._restriction_attrs.pop(node, None)

src/datajoint/table.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,8 @@ def delete(
973973
transaction: bool = True,
974974
prompt: bool | None = None,
975975
part_integrity: str = "enforce",
976-
) -> int:
976+
dry_run: bool = False,
977+
) -> int | dict[str, int]:
977978
"""
978979
Deletes the contents of the table and its dependent tables, recursively.
979980
@@ -991,9 +992,12 @@ def delete(
991992
- ``"enforce"`` (default): Error if parts would be deleted without masters.
992993
- ``"ignore"``: Allow deleting parts without masters (breaks integrity).
993994
- ``"cascade"``: Also delete masters when parts are deleted (maintains integrity).
995+
dry_run: If `True`, return a dict mapping full table names to affected
996+
row counts without deleting any data. Default False.
994997
995998
Returns:
996-
Number of deleted rows (excluding those from dependent tables).
999+
Number of deleted rows (excluding those from dependent tables), or
1000+
(if ``dry_run``) a dict mapping full table name to affected row count.
9971001
9981002
Raises:
9991003
DataJointError: When deleting within an existing transaction.
@@ -1006,7 +1010,7 @@ def delete(
10061010

10071011
diagram = Diagram._from_table(self)
10081012
diagram = diagram.cascade(self, part_integrity=part_integrity)
1009-
return diagram.delete(transaction=transaction, prompt=prompt)
1013+
return diagram.delete(transaction=transaction, prompt=prompt, dry_run=dry_run)
10101014

10111015
def drop_quick(self):
10121016
"""
@@ -1046,7 +1050,7 @@ def drop_quick(self):
10461050
else:
10471051
logger.info("Nothing to drop: table %s is not declared" % self.full_table_name)
10481052

1049-
def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
1053+
def drop(self, prompt: bool | None = None, part_integrity: str = "enforce", dry_run: bool = False):
10501054
"""
10511055
Drop the table and all tables that reference it, recursively.
10521056
@@ -1059,6 +1063,12 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
10591063
part_integrity: Policy for master-part integrity. One of:
10601064
- ``"enforce"`` (default): Error if parts would be dropped without masters.
10611065
- ``"ignore"``: Allow dropping parts without masters.
1066+
dry_run: If `True`, return a dict mapping full table names to row
1067+
counts without dropping any tables. Default False.
1068+
1069+
Returns:
1070+
dict[str, int] or None: If ``dry_run``, mapping of full table name
1071+
to row count. Otherwise None.
10621072
"""
10631073
if self.restriction:
10641074
raise DataJointError(
@@ -1067,7 +1077,7 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
10671077
from .diagram import Diagram
10681078

10691079
diagram = Diagram._from_table(self)
1070-
diagram.drop(prompt=prompt, part_integrity=part_integrity)
1080+
return diagram.drop(prompt=prompt, part_integrity=part_integrity, dry_run=dry_run)
10711081

10721082
def describe(self, context=None, printout=False):
10731083
"""

src/datajoint/user_tables.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def delete(self, part_integrity: str = "enforce", **kwargs):
239239
)
240240
super().delete(part_integrity=part_integrity, **kwargs)
241241

242-
def drop(self, part_integrity: str = "enforce"):
242+
def drop(self, part_integrity: str = "enforce", dry_run: bool = False):
243243
"""
244244
Drop a Part table.
245245
@@ -248,12 +248,13 @@ def drop(self, part_integrity: str = "enforce"):
248248
- ``"enforce"`` (default): Error - drop master instead.
249249
- ``"ignore"``: Allow direct drop (breaks master-part structure).
250250
Note: ``"cascade"`` is not supported for drop (too destructive).
251+
dry_run: If `True`, return row counts without dropping. Default False.
251252
252253
Raises:
253254
DataJointError: If part_integrity="enforce" (direct Part drops prohibited)
254255
"""
255256
if part_integrity == "ignore":
256-
super().drop(part_integrity="ignore")
257+
return super().drop(part_integrity="ignore", dry_run=dry_run)
257258
elif part_integrity == "enforce":
258259
raise DataJointError("Cannot drop a Part directly. Drop master instead, or use part_integrity='ignore' to force.")
259260
else:

0 commit comments

Comments
 (0)