Skip to content

Commit 0bede1d

Browse files
feat: implement Diagram.prune() to remove empty tables
Adds prune() method that removes tables with zero matching rows from the diagram. Without prior restrictions, removes physically empty tables. With restrictions (cascade or restrict), removes tables where the restricted query yields zero rows. Returns a new Diagram. Includes 5 integration tests: unrestricted prune, prune after restrict, prune after cascade, idempotency, and prune-then-restrict chaining. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b88ede7 commit 0bede1d

File tree

2 files changed

+150
-1
lines changed

2 files changed

+150
-1
lines changed

src/datajoint/diagram.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,51 @@ def preview(self):
772772
logger.info("{table} ({count} tuples)".format(table=t, count=count))
773773
return result
774774

775+
def prune(self):
    """
    Return a copy of this diagram with zero-row tables dropped.

    When neither cascade nor restrict conditions are present, every
    node in ``nodes_to_show`` is counted as-is and removed if it holds
    no rows. When conditions are present (cascade restrictions take
    precedence over restrict conditions), only the nodes that carry a
    condition are counted, each under its own condition, and the ones
    that match nothing are removed together with their condition entry.

    Node names consisting solely of digits are never examined.

    Returns
    -------
    Diagram
        New Diagram with empty tables removed.
    """
    from .table import FreeTable

    pruned = Diagram(self)
    # NOTE(review): whether Diagram(self) copies the restriction mappings
    # or shares them with ``self`` is not visible here -- confirm.
    active = pruned._cascade_restrictions or pruned._restrict_conditions

    if not active:
        # Unrestricted: check physical row counts
        for name in list(pruned.nodes_to_show):
            if name.isdigit():
                continue
            if len(FreeTable(self._connection, name)) == 0:
                pruned.nodes_to_show.discard(name)
        return pruned

    # Restricted: check row counts under restriction.
    # Iterate over a snapshot of the keys because entries are popped below.
    for name in list(active):
        if name.isdigit():
            continue
        table = FreeTable(self._connection, name)
        condition = active[name]
        if condition:
            is_plain_list = isinstance(condition, list) and not isinstance(condition, AndList)
            if is_plain_list:
                table.restrict_in_place(condition)
            else:
                # Anything else is assigned directly as the table's restriction.
                table._restriction = condition
        if len(table) == 0:
            active.pop(name)
            pruned._restriction_attrs.pop(name, None)
            pruned.nodes_to_show.discard(name)

    return pruned
819+
775820
def _make_graph(self) -> nx.DiGraph:
776821
"""
777822
Build graph object ready for drawing.

tests/integration/test_erd.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import pytest as _pytest
2+
13
import datajoint as dj
24

3-
from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L
5+
from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, Profile, Website
46

57

68
def test_decorator(schema_simp):
@@ -61,3 +63,105 @@ def test_part_table_parsing(schema_simp):
6163
graph = erd._make_graph()
6264
assert "OutfitLaunch" in graph.nodes()
6365
assert "OutfitLaunch.OutfitPiece" in graph.nodes()
66+
67+
68+
# --- prune() tests ---
69+
70+
71+
@_pytest.fixture
def schema_simp_pop(schema_simp):
    """Reset and refill the simple schema, then hand it to the prune tests.

    Profile and Website are cleared but never refilled, so they are left
    empty for the tests that follow.
    """
    # Call order is significant and kept as-is; presumably dependent tables
    # must be emptied before the tables they reference -- confirm.
    for table_cls in (Profile, Website, G, E, D, B, L, A):
        table_cls().delete()

    # Tables that are filled from their declared ``contents``.
    for table_cls in (A, L):
        table_cls().insert(table_cls.contents, skip_duplicates=True)

    # Tables that are filled through ``populate()``.
    for table_cls in (B, D, E, G):
        table_cls().populate()

    yield schema_simp
90+
91+
92+
def test_prune_unrestricted(schema_simp_pop):
    """With no restriction in place, prune() drops physically empty tables."""
    full = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
    nodes_before = len(full.nodes_to_show)
    remaining = full.prune().nodes_to_show

    # Tables filled by the fixture must still be present.
    for cls in (A, B, D, E, L):
        assert cls.full_table_name in remaining, f"{cls.__name__} should not be pruned"

    # Profile is left empty by the fixture, so it must be gone.
    assert Profile.full_table_name not in remaining, "empty Profile should be pruned"

    # At least one node was removed overall.
    assert len(remaining) < nodes_before
107+
108+
109+
def test_prune_after_restrict(schema_simp_pop):
    """After restrict(), prune() drops tables whose restricted count is zero."""
    restricted = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE).restrict(A & "id_a=0")
    counts_before = restricted.preview()

    pruned = restricted.prune()
    counts_after = pruned.preview()

    # Nothing that survives pruning may report a non-positive count.
    assert not any(c <= 0 for c in counts_after.values()), "pruned diagram should have no zero-count tables"

    # Every table that previewed as empty must have lost its condition entry.
    previously_empty = [table for table, count in counts_before.items() if count == 0]
    for table in previously_empty:
        assert table not in pruned._restrict_conditions, f"{table} had 0 rows but was not pruned"
125+
126+
127+
def test_prune_after_cascade(schema_simp_pop):
    """After cascade(), prune() drops tables whose restricted count is zero."""
    cascaded = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE).cascade(A & "id_a=0")
    counts_before = cascaded.preview()

    pruned = cascaded.prune()
    counts_after = pruned.preview()

    # Nothing that survives pruning may report a non-positive count.
    assert not any(c <= 0 for c in counts_after.values())

    # Every table that previewed as empty must have lost its cascade entry.
    previously_empty = [table for table, count in counts_before.items() if count == 0]
    for table in previously_empty:
        assert table not in pruned._cascade_restrictions, f"{table} had 0 rows but was not pruned"
141+
142+
143+
def test_prune_idempotent(schema_simp_pop):
    """A second prune() leaves the visible nodes and conditions unchanged."""
    base = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
    first_pass = base.restrict(A & "id_a=0").prune()
    second_pass = first_pass.prune()

    # Same set of nodes on display after either pass.
    assert first_pass.nodes_to_show == second_pass.nodes_to_show

    # Same tables still carry a restrict condition after either pass.
    keys_after_first = set(first_pass._restrict_conditions)
    keys_after_second = set(second_pass._restrict_conditions)
    assert keys_after_first == keys_after_second
152+
153+
154+
def test_prune_then_restrict(schema_simp_pop):
    """restrict() still works on a pruned diagram and never widens counts."""
    base = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
    pruned = base.restrict(A & "id_a < 5").prune()

    # Restrict the same seed table again, this time with a tighter condition.
    narrowed = pruned.restrict(A & "id_a=0")

    # preview() must succeed on the re-restricted diagram; the non-negativity
    # check is only a sanity check on the values it returns.
    narrowed_counts = narrowed.preview()
    assert not any(c < 0 for c in narrowed_counts.values())

    # Per table, the tighter condition may not match more rows than the
    # looser one; a table absent from the pruned preview is treated as 0.
    baseline_counts = pruned.preview()
    for table, narrowed_count in narrowed_counts.items():
        assert narrowed_count <= baseline_counts.get(table, 0)

0 commit comments

Comments
 (0)