Skip to content

Commit afd1f63

Browse files
fix: cascade delete uses restrict() for proper SQL generation and OR convergence
Replace direct `_restriction` assignment with `restrict()` calls in Diagram so that AndList and QueryExpression objects are converted to valid SQL via `make_condition()`. Cascade delete now uses OR convergence (a row is deleted if ANY FK reference points to a deleted row), while restrict/export uses AND. Also adds dry_run support to delete() and drop(), fixes CLI test subprocess invocation to use `sys.executable -m datajoint.cli`, and fixes test fixture cleanup to respect part_integrity policy. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ae0eddd commit afd1f63

File tree

7 files changed

+1049
-242
lines changed

7 files changed

+1049
-242
lines changed

pixi.lock

Lines changed: 869 additions & 203 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/datajoint/diagram.py

Lines changed: 72 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ def __init__(self, source, context=None) -> None:
9797
self._cascade_restrictions = copy_module.deepcopy(source._cascade_restrictions)
9898
self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
9999
self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
100+
self._part_integrity = getattr(source, "_part_integrity", "enforce")
100101
super().__init__(source)
101102
return
102103

@@ -369,6 +370,7 @@ def cascade(self, table_expr, part_integrity="enforce"):
369370
"cascade and restrict modes are mutually exclusive."
370371
)
371372
result = Diagram(self)
373+
result._part_integrity = part_integrity
372374
node = table_expr.full_table_name
373375
if node not in result.nodes():
374376
raise DataJointError(f"Table {node} is not in the diagram.")
@@ -380,6 +382,34 @@ def cascade(self, table_expr, part_integrity="enforce"):
380382
result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)
381383
return result
382384

385+
@staticmethod
386+
def _restrict_freetable(ft, restrictions, mode="cascade"):
387+
"""
388+
Apply cascade/restrict restrictions to a FreeTable.
389+
390+
Uses ``restrict()`` to properly convert each restriction (AndList,
391+
QueryExpression, etc.) into SQL via ``make_condition``, rather than
392+
assigning raw objects to ``_restriction`` which would produce
393+
invalid SQL in ``where_clause``.
394+
395+
For cascade mode (delete), restrictions from different parent edges
396+
are OR-ed: a row is deleted if ANY of its FK references point to a
397+
deleted row.
398+
399+
For restrict mode (export), restrictions are AND-ed: a row is
400+
included only if ALL ancestor conditions are satisfied.
401+
"""
402+
if not restrictions:
403+
return ft
404+
if mode == "cascade":
405+
# OR semantics — passing a list to restrict() creates an OrList
406+
return ft.restrict(restrictions)
407+
else:
408+
# AND semantics — each restriction narrows further
409+
for r in restrictions:
410+
ft = ft.restrict(r)
411+
return ft
412+
383413
def restrict(self, table_expr):
384414
"""
385415
Apply restrict condition and propagate downstream.
@@ -445,11 +475,8 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
445475
# Build parent FreeTable with current restriction
446476
parent_ft = FreeTable(self._connection, node)
447477
restr = restrictions[node]
448-
if mode == "cascade" and restr:
449-
parent_ft._restriction = restr # plain list → OR
450-
elif mode == "restrict":
451-
parent_ft._restriction = restr # AndList → AND
452-
# else: cascade with empty list → unrestricted
478+
if restr:
479+
parent_ft = self._restrict_freetable(parent_ft, restr, mode=mode)
453480

454481
parent_attrs = self._restriction_attrs.get(node, set())
455482

@@ -507,14 +534,14 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
507534
child_ft = FreeTable(self._connection, target)
508535
child_restr = restrictions.get(target, [])
509536
if child_restr:
510-
child_ft._restriction = child_restr
537+
child_ft = self._restrict_freetable(child_ft, child_restr, mode=mode)
511538
master_ft = FreeTable(self._connection, master_name)
512539
from .condition import make_condition
513540

514541
master_restr = make_condition(
515542
master_ft,
516543
(master_ft.proj() & child_ft.proj()).to_arrays(),
517-
master_ft._restriction_attributes,
544+
master_ft.restriction_attributes,
518545
)
519546
restrictions[master_name] = [master_restr]
520547
self._restriction_attrs[master_name] = set()
@@ -579,7 +606,7 @@ def _apply_propagation_rule(
579606

580607
self._restriction_attrs.setdefault(child_node, set()).update(child_attrs)
581608

582-
def delete(self, transaction=True, prompt=None):
609+
def delete(self, transaction=True, prompt=None, dry_run=False):
583610
"""
584611
Execute cascading delete using cascade restrictions.
585612
@@ -589,14 +616,20 @@ def delete(self, transaction=True, prompt=None):
589616
Wrap in a transaction. Default True.
590617
prompt : bool or None, optional
591618
Show preview and ask confirmation. Default ``dj.config['safemode']``.
619+
dry_run : bool, optional
620+
If True, return affected row counts without deleting. Default False.
592621
593622
Returns
594623
-------
595-
int
596-
Number of rows deleted from the root table.
624+
int or dict[str, int]
625+
Number of rows deleted from the root table, or (if ``dry_run``)
626+
a mapping of full table name to affected row count.
597627
"""
598628
from .table import FreeTable
599629

630+
if dry_run:
631+
return self.preview()
632+
600633
prompt = self._connection._config["safemode"] if prompt is None else prompt
601634

602635
if not self._cascade_restrictions:
@@ -606,14 +639,15 @@ def delete(self, transaction=True, prompt=None):
606639

607640
# Pre-check part_integrity="enforce": ensure no part is deleted
608641
# before its master
609-
for node in self._cascade_restrictions:
610-
master = extract_master(node)
611-
if master and master not in self._cascade_restrictions:
612-
raise DataJointError(
613-
f"Attempt to delete part table {node} before "
614-
f"its master {master}. Delete from the master first, "
615-
f"or use part_integrity='ignore' or 'cascade'."
616-
)
642+
if getattr(self, "_part_integrity", "enforce") == "enforce":
643+
for node in self._cascade_restrictions:
644+
master = extract_master(node)
645+
if master and master not in self._cascade_restrictions:
646+
raise DataJointError(
647+
f"Attempt to delete part table {node} before "
648+
f"its master {master}. Delete from the master first, "
649+
f"or use part_integrity='ignore' or 'cascade'."
650+
)
617651

618652
# Get non-alias nodes with restrictions in topological order
619653
all_sorted = topo_sort(self)
@@ -623,9 +657,7 @@ def delete(self, transaction=True, prompt=None):
623657
if prompt:
624658
for t in tables:
625659
ft = FreeTable(conn, t)
626-
restr = self._cascade_restrictions[t]
627-
if restr:
628-
ft._restriction = restr
660+
ft = self._restrict_freetable(ft, self._cascade_restrictions[t])
629661
logger.info("{table} ({count} tuples)".format(table=t, count=len(ft)))
630662

631663
# Start transaction
@@ -647,9 +679,7 @@ def delete(self, transaction=True, prompt=None):
647679
try:
648680
for table_name in reversed(tables):
649681
ft = FreeTable(conn, table_name)
650-
restr = self._cascade_restrictions[table_name]
651-
if restr:
652-
ft._restriction = restr
682+
ft = self._restrict_freetable(ft, self._cascade_restrictions[table_name])
653683
count = ft.delete_quick(get_count=True)
654684
logger.info("Deleting {count} rows from {table}".format(count=count, table=table_name))
655685
if table_name == tables[0]:
@@ -692,7 +722,7 @@ def delete(self, transaction=True, prompt=None):
692722
root_count = 0
693723
return root_count
694724

695-
def drop(self, prompt=None, part_integrity="enforce"):
725+
def drop(self, prompt=None, part_integrity="enforce", dry_run=False):
696726
"""
697727
Drop all tables in the diagram in reverse topological order.
698728
@@ -702,6 +732,13 @@ def drop(self, prompt=None, part_integrity="enforce"):
702732
Show preview and ask confirmation. Default ``dj.config['safemode']``.
703733
part_integrity : str, optional
704734
``"enforce"`` (default) or ``"ignore"``.
735+
dry_run : bool, optional
736+
If True, return row counts without dropping. Default False.
737+
738+
Returns
739+
-------
740+
dict[str, int] or None
741+
If ``dry_run``, mapping of full table name to row count.
705742
"""
706743
from .table import FreeTable
707744

@@ -720,6 +757,14 @@ def drop(self, prompt=None, part_integrity="enforce"):
720757
)
721758
)
722759

760+
if dry_run:
761+
result = {}
762+
for t in tables:
763+
count = len(FreeTable(conn, t))
764+
result[t] = count
765+
logger.info("{table} ({count} tuples)".format(table=t, count=count))
766+
return result
767+
723768
do_drop = True
724769
if prompt:
725770
for t in tables:
@@ -742,6 +787,7 @@ def preview(self):
742787
from .table import FreeTable
743788

744789
restrictions = self._cascade_restrictions or self._restrict_conditions
790+
mode = "cascade" if self._cascade_restrictions else "restrict"
745791
if not restrictions:
746792
raise DataJointError("No restrictions applied. " "Call cascade() or restrict() first.")
747793

@@ -750,9 +796,7 @@ def preview(self):
750796
if node.isdigit() or node not in restrictions:
751797
continue
752798
ft = FreeTable(self._connection, node)
753-
restr = restrictions[node]
754-
if restr:
755-
ft._restriction = restr
799+
ft = self._restrict_freetable(ft, restrictions[node], mode=mode)
756800
result[node] = len(ft)
757801

758802
for t, count in result.items():

src/datajoint/table.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -973,7 +973,8 @@ def delete(
973973
transaction: bool = True,
974974
prompt: bool | None = None,
975975
part_integrity: str = "enforce",
976-
) -> int:
976+
dry_run: bool = False,
977+
) -> int | dict[str, int]:
977978
"""
978979
Deletes the contents of the table and its dependent tables, recursively.
979980
@@ -991,9 +992,12 @@ def delete(
991992
- ``"enforce"`` (default): Error if parts would be deleted without masters.
992993
- ``"ignore"``: Allow deleting parts without masters (breaks integrity).
993994
- ``"cascade"``: Also delete masters when parts are deleted (maintains integrity).
995+
dry_run: If `True`, return a dict mapping full table names to affected
996+
row counts without deleting any data. Default False.
994997
995998
Returns:
996-
Number of deleted rows (excluding those from dependent tables).
999+
Number of deleted rows (excluding those from dependent tables), or
1000+
(if ``dry_run``) a dict mapping full table name to affected row count.
9971001
9981002
Raises:
9991003
DataJointError: When deleting within an existing transaction.
@@ -1006,7 +1010,7 @@ def delete(
10061010

10071011
diagram = Diagram._from_table(self)
10081012
diagram = diagram.cascade(self, part_integrity=part_integrity)
1009-
return diagram.delete(transaction=transaction, prompt=prompt)
1013+
return diagram.delete(transaction=transaction, prompt=prompt, dry_run=dry_run)
10101014

10111015
def drop_quick(self):
10121016
"""
@@ -1046,7 +1050,7 @@ def drop_quick(self):
10461050
else:
10471051
logger.info("Nothing to drop: table %s is not declared" % self.full_table_name)
10481052

1049-
def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
1053+
def drop(self, prompt: bool | None = None, part_integrity: str = "enforce", dry_run: bool = False):
10501054
"""
10511055
Drop the table and all tables that reference it, recursively.
10521056
@@ -1059,6 +1063,12 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
10591063
part_integrity: Policy for master-part integrity. One of:
10601064
- ``"enforce"`` (default): Error if parts would be dropped without masters.
10611065
- ``"ignore"``: Allow dropping parts without masters.
1066+
dry_run: If `True`, return a dict mapping full table names to row
1067+
counts without dropping any tables. Default False.
1068+
1069+
Returns:
1070+
dict[str, int] or None: If ``dry_run``, mapping of full table name
1071+
to row count. Otherwise None.
10621072
"""
10631073
if self.restriction:
10641074
raise DataJointError(
@@ -1067,7 +1077,7 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
10671077
from .diagram import Diagram
10681078

10691079
diagram = Diagram._from_table(self)
1070-
diagram.drop(prompt=prompt, part_integrity=part_integrity)
1080+
return diagram.drop(prompt=prompt, part_integrity=part_integrity, dry_run=dry_run)
10711081

10721082
def describe(self, context=None, printout=False):
10731083
"""

src/datajoint/user_tables.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ def delete(self, part_integrity: str = "enforce", **kwargs):
239239
)
240240
super().delete(part_integrity=part_integrity, **kwargs)
241241

242-
def drop(self, part_integrity: str = "enforce"):
242+
def drop(self, part_integrity: str = "enforce", dry_run: bool = False):
243243
"""
244244
Drop a Part table.
245245
@@ -248,12 +248,13 @@ def drop(self, part_integrity: str = "enforce"):
248248
- ``"enforce"`` (default): Error - drop master instead.
249249
- ``"ignore"``: Allow direct drop (breaks master-part structure).
250250
Note: ``"cascade"`` is not supported for drop (too destructive).
251+
dry_run: If `True`, return row counts without dropping. Default False.
251252
252253
Raises:
253254
DataJointError: If part_integrity="enforce" (direct Part drops prohibited)
254255
"""
255256
if part_integrity == "ignore":
256-
super().drop(part_integrity="ignore")
257+
return super().drop(part_integrity="ignore", dry_run=dry_run)
257258
elif part_integrity == "enforce":
258259
raise DataJointError("Cannot drop a Part directly. Drop master instead, or use part_integrity='ignore' to force.")
259260
else:

tests/integration/test_cascade_delete.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,3 +188,86 @@ class Observation(dj.Manual):
188188
assert remaining_obs[0]["obs_id"] == 3
189189
assert remaining_obs[0]["subject_id"] == 2
190190
assert remaining_obs[0]["measurement"] == 15.3
191+
192+
193+
def test_delete_dry_run(schema_by_backend):
194+
"""dry_run=True returns affected row counts without deleting data."""
195+
196+
@schema_by_backend
197+
class Parent(dj.Manual):
198+
definition = """
199+
parent_id : int
200+
---
201+
name : varchar(255)
202+
"""
203+
204+
@schema_by_backend
205+
class Child(dj.Manual):
206+
definition = """
207+
-> Parent
208+
child_id : int
209+
---
210+
data : varchar(255)
211+
"""
212+
213+
Parent.insert1((1, "P1"))
214+
Parent.insert1((2, "P2"))
215+
Child.insert1((1, 1, "C1-1"))
216+
Child.insert1((1, 2, "C1-2"))
217+
Child.insert1((2, 1, "C2-1"))
218+
219+
# dry_run on restricted delete
220+
counts = (Parent & {"parent_id": 1}).delete(dry_run=True)
221+
222+
assert isinstance(counts, dict)
223+
assert counts[Parent.full_table_name] == 1
224+
assert counts[Child.full_table_name] == 2
225+
226+
# Data must still be intact
227+
assert len(Parent()) == 2
228+
assert len(Child()) == 3
229+
230+
# dry_run on unrestricted delete
231+
counts_all = Parent.delete(dry_run=True)
232+
assert counts_all[Parent.full_table_name] == 2
233+
assert counts_all[Child.full_table_name] == 3
234+
235+
# Still intact
236+
assert len(Parent()) == 2
237+
assert len(Child()) == 3
238+
239+
240+
def test_drop_dry_run(schema_by_backend):
241+
"""dry_run=True returns row counts without dropping tables."""
242+
243+
@schema_by_backend
244+
class Parent(dj.Manual):
245+
definition = """
246+
parent_id : int
247+
---
248+
name : varchar(255)
249+
"""
250+
251+
@schema_by_backend
252+
class Child(dj.Manual):
253+
definition = """
254+
-> Parent
255+
child_id : int
256+
---
257+
data : varchar(255)
258+
"""
259+
260+
Parent.insert1((1, "P1"))
261+
Child.insert1((1, 1, "C1"))
262+
263+
counts = Parent.drop(dry_run=True)
264+
265+
assert isinstance(counts, dict)
266+
assert counts[Parent.full_table_name] == 1
267+
assert counts[Child.full_table_name] == 1
268+
269+
# Tables must still exist and have data
270+
assert Parent.is_declared
271+
assert Child.is_declared
272+
assert len(Parent()) == 1
273+
assert len(Child()) == 1

0 commit comments

Comments
 (0)