Skip to content

Commit d2626e0

Browse files
fix: use post-hoc enforce check matching old Table.delete() behavior
The pre-check on the cascade graph was too conservative — it flagged part tables that appeared in the graph but had zero rows to delete. The old code checked actual deletions within a transaction. Replace the graph-based pre-check with a post-hoc check on deleted_tables (tables that actually had rows deleted). If a part table had rows deleted without its master also having rows deleted, roll back the transaction and raise DataJointError. This matches the original part_integrity="enforce" semantics. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2d2693 commit d2626e0

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

src/datajoint/diagram.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def __init__(self, source, context=None) -> None:
9898
self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
9999
self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
100100
self._part_integrity = source._part_integrity
101-
self._cascade_seed = source._cascade_seed
102101
super().__init__(source)
103102
return
104103

@@ -127,7 +126,6 @@ def __init__(self, source, context=None) -> None:
127126
self._restrict_conditions = {}
128127
self._restriction_attrs = {}
129128
self._part_integrity = "enforce"
130-
self._cascade_seed = None
131129

132130
# Enumerate nodes from all the items in the list
133131
self.nodes_to_show = set()
@@ -196,7 +194,6 @@ def _from_table(cls, table_expr) -> "Diagram":
196194
result._restrict_conditions = {}
197195
result._restriction_attrs = {}
198196
result._part_integrity = "enforce"
199-
result._cascade_seed = None
200197
return result
201198

202199
def add_parts(self) -> "Diagram":
@@ -376,7 +373,6 @@ def cascade(self, table_expr, part_integrity="enforce"):
376373
)
377374
result = Diagram(self)
378375
result._part_integrity = part_integrity
379-
result._cascade_seed = table_expr.full_table_name
380376
node = table_expr.full_table_name
381377
if node not in result.nodes():
382378
raise DataJointError(f"Table {node} is not in the diagram.")
@@ -612,20 +608,6 @@ def delete(self, transaction=True, prompt=None):
612608

613609
conn = self._connection
614610

615-
# Pre-check part_integrity="enforce": ensure no part is deleted
616-
# before its master (skip the cascade seed — explicitly targeted)
617-
if self._part_integrity == "enforce":
618-
for node in self._cascade_restrictions:
619-
if node == self._cascade_seed:
620-
continue
621-
master = extract_master(node)
622-
if master and master not in self._cascade_restrictions:
623-
raise DataJointError(
624-
f"Attempt to delete part table {node} before "
625-
f"its master {master}. Delete from the master first, "
626-
f"or use part_integrity='ignore' or 'cascade'."
627-
)
628-
629611
# Get non-alias nodes with restrictions in topological order
630612
all_sorted = topo_sort(self)
631613
tables = [t for t in all_sorted if not t.isdigit() and t in self._cascade_restrictions]
@@ -655,13 +637,16 @@ def delete(self, transaction=True, prompt=None):
655637

656638
# Execute deletes in reverse topological order (leaves first)
657639
root_count = 0
640+
deleted_tables = set()
658641
try:
659642
for table_name in reversed(tables):
660643
ft = FreeTable(conn, table_name)
661644
restr = self._cascade_restrictions[table_name]
662645
if restr:
663646
ft.restrict_in_place(restr)
664647
count = ft.delete_quick(get_count=True)
648+
if count > 0:
649+
deleted_tables.add(table_name)
665650
logger.info("Deleting {count} rows from {table}".format(count=count, table=table_name))
666651
if table_name == tables[0]:
667652
root_count = count
@@ -681,6 +666,20 @@ def delete(self, transaction=True, prompt=None):
681666
conn.cancel_transaction()
682667
raise
683668

669+
# Post-check part_integrity="enforce": roll back if a part table
670+
# had rows deleted without its master also having rows deleted.
671+
if self._part_integrity == "enforce" and deleted_tables:
672+
for table_name in deleted_tables:
673+
master = extract_master(table_name)
674+
if master and master not in deleted_tables:
675+
if transaction:
676+
conn.cancel_transaction()
677+
raise DataJointError(
678+
f"Attempt to delete part table {table_name} before "
679+
f"its master {master}. Delete from the master first, "
680+
f"or use part_integrity='ignore' or 'cascade'."
681+
)
682+
684683
# Confirm and commit
685684
if root_count == 0:
686685
if prompt:

0 commit comments

Comments
 (0)