|
| 1 | +import pytest as _pytest |
| 2 | + |
1 | 3 | import datajoint as dj |
2 | 4 |
|
3 | | -from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L |
| 5 | +from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, Profile, Website |
4 | 6 |
|
5 | 7 |
|
6 | 8 | def test_decorator(schema_simp): |
@@ -61,3 +63,105 @@ def test_part_table_parsing(schema_simp): |
61 | 63 | graph = erd._make_graph() |
62 | 64 | assert "OutfitLaunch" in graph.nodes() |
63 | 65 | assert "OutfitLaunch.OutfitPiece" in graph.nodes() |
| 66 | + |
| 67 | + |
| 68 | +# --- prune() tests --- |
| 69 | + |
| 70 | + |
| 71 | +@_pytest.fixture |
| 72 | +def schema_simp_pop(schema_simp): |
| 73 | + """Populate the simple schema for prune tests.""" |
| 74 | + Profile().delete() |
| 75 | + Website().delete() |
| 76 | + G().delete() |
| 77 | + E().delete() |
| 78 | + D().delete() |
| 79 | + B().delete() |
| 80 | + L().delete() |
| 81 | + A().delete() |
| 82 | + |
| 83 | + A().insert(A.contents, skip_duplicates=True) |
| 84 | + L().insert(L.contents, skip_duplicates=True) |
| 85 | + B().populate() |
| 86 | + D().populate() |
| 87 | + E().populate() |
| 88 | + G().populate() |
| 89 | + yield schema_simp |
| 90 | + |
| 91 | + |
| 92 | +def test_prune_unrestricted(schema_simp_pop): |
| 93 | + """Prune on unrestricted diagram removes physically empty tables.""" |
| 94 | + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) |
| 95 | + original_count = len(diag.nodes_to_show) |
| 96 | + pruned = diag.prune() |
| 97 | + |
| 98 | + # Populated tables (A, L, B, B.C, D, E, E.F, G, etc.) should survive |
| 99 | + for cls in (A, B, D, E, L): |
| 100 | + assert cls.full_table_name in pruned.nodes_to_show, f"{cls.__name__} should not be pruned" |
| 101 | + |
| 102 | + # Empty tables like Profile should be removed |
| 103 | + assert Profile.full_table_name not in pruned.nodes_to_show, "empty Profile should be pruned" |
| 104 | + |
| 105 | + # Pruned diagram should have fewer nodes |
| 106 | + assert len(pruned.nodes_to_show) < original_count |
| 107 | + |
| 108 | + |
| 109 | +def test_prune_after_restrict(schema_simp_pop): |
| 110 | + """Prune after restrict removes tables with zero matching rows.""" |
| 111 | + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) |
| 112 | + restricted = diag.restrict(A & "id_a=0") |
| 113 | + counts = restricted.preview() |
| 114 | + |
| 115 | + pruned = restricted.prune() |
| 116 | + pruned_counts = pruned.preview() |
| 117 | + |
| 118 | + # Every table in pruned preview should have > 0 rows |
| 119 | + assert all(c > 0 for c in pruned_counts.values()), "pruned diagram should have no zero-count tables" |
| 120 | + |
| 121 | + # Tables with zero rows in the original preview should be gone |
| 122 | + for table, count in counts.items(): |
| 123 | + if count == 0: |
| 124 | + assert table not in pruned._restrict_conditions, f"{table} had 0 rows but was not pruned" |
| 125 | + |
| 126 | + |
| 127 | +def test_prune_after_cascade(schema_simp_pop): |
| 128 | + """Prune after cascade removes tables with zero matching rows.""" |
| 129 | + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) |
| 130 | + cascaded = diag.cascade(A & "id_a=0") |
| 131 | + counts = cascaded.preview() |
| 132 | + |
| 133 | + pruned = cascaded.prune() |
| 134 | + pruned_counts = pruned.preview() |
| 135 | + |
| 136 | + assert all(c > 0 for c in pruned_counts.values()) |
| 137 | + |
| 138 | + for table, count in counts.items(): |
| 139 | + if count == 0: |
| 140 | + assert table not in pruned._cascade_restrictions, f"{table} had 0 rows but was not pruned" |
| 141 | + |
| 142 | + |
| 143 | +def test_prune_idempotent(schema_simp_pop): |
| 144 | + """Pruning twice gives the same result.""" |
| 145 | + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) |
| 146 | + restricted = diag.restrict(A & "id_a=0") |
| 147 | + pruned_once = restricted.prune() |
| 148 | + pruned_twice = pruned_once.prune() |
| 149 | + |
| 150 | + assert pruned_once.nodes_to_show == pruned_twice.nodes_to_show |
| 151 | + assert set(pruned_once._restrict_conditions) == set(pruned_twice._restrict_conditions) |
| 152 | + |
| 153 | + |
| 154 | +def test_prune_then_restrict(schema_simp_pop): |
| 155 | + """Restrict can be called after prune.""" |
| 156 | + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) |
| 157 | + pruned = diag.restrict(A & "id_a < 5").prune() |
| 158 | + # Restrict again on the same seed table with a tighter condition |
| 159 | + further = pruned.restrict(A & "id_a=0") |
| 160 | + |
| 161 | + # Should not raise; further restriction should narrow results |
| 162 | + counts = further.preview() |
| 163 | + assert all(c >= 0 for c in counts.values()) |
| 164 | + # Tighter restriction should produce fewer or equal rows |
| 165 | + pruned_counts = pruned.preview() |
| 166 | + for table in counts: |
| 167 | + assert counts[table] <= pruned_counts.get(table, 0) |
0 commit comments