Skip to content

Commit 4e7241d

Browse files
authored
fix: Django fixer E2E tests, --since docs, mrt explain file check order + tests (#44)
1 parent 7c07e5d commit 4e7241d

30 files changed

Lines changed: 1768 additions & 440 deletions

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,11 @@ Add to `.pre-commit-config.yaml` to run `mrt check` automatically before every p
174174
Check only migrations added since a given revision. Keeps CI fast on large codebases:
175175

176176
```bash
177-
mrt check migrations/versions/ --since main
178-
mrt check myapp/migrations/ --since v1.2.0
177+
# Alembic — pass a revision ID
178+
mrt check migrations/versions/ --since a1b2c3d4
179+
180+
# Django — pass app_label.migration_name
181+
mrt check myapp/migrations/ --since myapp.0010_add_email
179182
```
180183

181184
## CI/CD integration

docs/cli.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,21 @@ mrt check migrations/versions/
6161

6262
#### `--since` — incremental scanning
6363

64-
In large codebases, scanning the full migration history on every PR is wasteful. `--since` limits the scan to migrations whose files were added after the given git ref:
64+
In large codebases, scanning the full migration history on every PR is wasteful. `--since` limits the scan to migrations that come *after* the given revision in the migration dependency chain.
6565

66+
**Alembic** — pass a revision ID:
6667
```bash
67-
mrt check migrations/versions/ --since main
68-
mrt check myapp/migrations/ --since v1.2.0
69-
mrt check myapp/migrations/ --since HEAD~5
68+
mrt check migrations/versions/ --since a1b2c3d4
7069
```
70+
Only revisions whose `down_revision` ancestry includes `a1b2c3d4` are checked.
7171

72-
This is the recommended setup for CI: check only the migrations that changed in the current PR branch.
72+
**Django** — pass `app_label.migration_name`:
73+
```bash
74+
mrt check myapp/migrations/ --since myapp.0010_add_email
75+
```
76+
Only migrations that depend on `0010_add_email` (directly or transitively) are checked.
77+
78+
This is the recommended CI pattern: pass the last migration on the base branch so only the PR's new migrations are scanned.
7379

7480
### Exit codes
7581

pytest_mrt/commands/output.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def explain(
3838
Requires: pip install pytest-mrt[ai]
3939
Requires: ANTHROPIC_API_KEY environment variable
4040
"""
41+
path = Path(migration_file)
42+
if not path.exists():
43+
console.print(f"[red]File not found: {migration_file}[/red]")
44+
raise typer.Exit(1)
45+
4146
try:
4247
import anthropic
4348
except ImportError:
@@ -50,11 +55,6 @@ def explain(
5055
)
5156
raise typer.Exit(1)
5257

53-
path = Path(migration_file)
54-
if not path.exists():
55-
console.print(f"[red]File not found: {migration_file}[/red]")
56-
raise typer.Exit(1)
57-
5858
source = path.read_text()
5959

6060
console.print(f"[dim]Analyzing {path.name}...[/dim]")

tests/django_fixer_app/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from django.db import migrations, models
2+
3+
4+
class Migration(migrations.Migration):
5+
initial = True
6+
dependencies = []
7+
8+
operations = [
9+
migrations.CreateModel(
10+
name="Contact",
11+
fields=[
12+
("id", models.AutoField(primary_key=True, serialize=False)),
13+
("name", models.CharField(max_length=128)),
14+
("phone", models.CharField(blank=True, max_length=32, null=True)),
15+
],
16+
options={"app_label": "django_fixer_app"},
17+
),
18+
]
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Reference migration: matches the pattern that `mrt fix` generates for RemoveField.
2+
# Used by test_django_fixer_e2e.py to verify that the generated code pattern
3+
# actually executes correctly against a real database.
4+
from django.db import migrations
5+
6+
_MRT_TABLE = "_mrt_backups"
7+
_MRT_CHUNK = 500
8+
9+
10+
def __mrt_enc(v):
11+
import base64
12+
from datetime import date, datetime, time
13+
from decimal import Decimal
14+
from uuid import UUID
15+
16+
if v is None:
17+
return None
18+
if isinstance(v, bool):
19+
return v
20+
if isinstance(v, int):
21+
return v
22+
if isinstance(v, float):
23+
return v
24+
if isinstance(v, Decimal):
25+
return "D:" + str(v)
26+
if isinstance(v, datetime):
27+
if v.tzinfo is not None:
28+
return "DTs:" + str(v.timestamp())
29+
return "DT:" + v.isoformat()
30+
if isinstance(v, date):
31+
return "d:" + v.isoformat()
32+
if isinstance(v, time):
33+
return "t:" + v.isoformat()
34+
if isinstance(v, UUID):
35+
return "U:" + str(v)
36+
if isinstance(v, (bytes, bytearray, memoryview)):
37+
return "B:" + base64.b64encode(bytes(v)).decode()
38+
s = str(v)
39+
for prefix in ("D:", "DT:", "DTs:", "d:", "t:", "U:", "B:", "S:"):
40+
if s.startswith(prefix):
41+
return "S:" + s
42+
return s
43+
44+
45+
def __mrt_dec(v):
46+
import base64
47+
from datetime import date, datetime, time
48+
from decimal import Decimal
49+
from uuid import UUID
50+
51+
if not isinstance(v, str):
52+
return v
53+
if v.startswith("D:"):
54+
return Decimal(v[2:])
55+
if v.startswith("DTs:"):
56+
return datetime.fromtimestamp(float(v[4:]))
57+
if v.startswith("DT:"):
58+
return datetime.fromisoformat(v[3:])
59+
if v.startswith("d:"):
60+
return date.fromisoformat(v[2:])
61+
if v.startswith("t:"):
62+
return time.fromisoformat(v[2:])
63+
if v.startswith("U:"):
64+
return UUID(v[2:])
65+
if v.startswith("B:"):
66+
return base64.b64decode(v[2:])
67+
if v.startswith("S:"):
68+
return v[2:]
69+
return v
70+
71+
72+
_MRT_LABEL_contact_phone = "0002_remove_phone__contact_phone"
73+
74+
75+
def _backup_contact_phone(apps, schema_editor):
76+
import json
77+
78+
from django.db import connection
79+
80+
Contact = apps.get_model("django_fixer_app", "Contact")
81+
with connection.cursor() as cur:
82+
cur.execute(
83+
"CREATE TABLE IF NOT EXISTS " + _MRT_TABLE + " "
84+
"(migration_label TEXT NOT NULL, object_id TEXT NOT NULL, payload TEXT)"
85+
)
86+
cur.execute(
87+
"DELETE FROM " + _MRT_TABLE + " WHERE migration_label = %s",
88+
[_MRT_LABEL_contact_phone],
89+
)
90+
last_pk = None
91+
while True:
92+
qs = Contact.objects.order_by("pk")
93+
if last_pk is not None:
94+
qs = qs.filter(pk__gt=last_pk)
95+
batch = list(qs.values_list("pk", "phone")[:_MRT_CHUNK])
96+
if not batch:
97+
break
98+
with connection.cursor() as cur:
99+
for pk, val in batch:
100+
cur.execute(
101+
"INSERT INTO " + _MRT_TABLE + " VALUES (%s, %s, %s)",
102+
[
103+
_MRT_LABEL_contact_phone,
104+
json.dumps(__mrt_enc(pk)),
105+
json.dumps(__mrt_enc(val)),
106+
],
107+
)
108+
last_pk = batch[-1][0]
109+
110+
111+
def _restore_contact_phone(apps, schema_editor):
112+
import json
113+
114+
from django.db import connection
115+
116+
Contact = apps.get_model("django_fixer_app", "Contact")
117+
with connection.cursor() as cur:
118+
cur.execute(
119+
"SELECT object_id, payload FROM " + _MRT_TABLE + " WHERE migration_label = %s",
120+
[_MRT_LABEL_contact_phone],
121+
)
122+
rows = cur.fetchall()
123+
for pk_raw, val_raw in rows:
124+
pk = __mrt_dec(json.loads(pk_raw))
125+
val = __mrt_dec(json.loads(val_raw))
126+
Contact.objects.filter(pk=pk).update(phone=val)
127+
with connection.cursor() as cur:
128+
cur.execute(
129+
"DELETE FROM " + _MRT_TABLE + " WHERE migration_label = %s",
130+
[_MRT_LABEL_contact_phone],
131+
)
132+
133+
134+
class Migration(migrations.Migration):
135+
dependencies = [("django_fixer_app", "0001_initial")]
136+
137+
operations = [
138+
migrations.RunPython(_backup_contact_phone, _restore_contact_phone),
139+
migrations.RemoveField(model_name="Contact", name="phone"),
140+
]

tests/django_fixer_app/migrations/__init__.py

Whitespace-only changes.

tests/django_fixer_app/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from django.db import models
2+
3+
4+
class Contact(models.Model):
5+
name = models.CharField(max_length=128)
6+
phone = models.CharField(max_length=32, null=True, blank=True)
7+
8+
class Meta:
9+
app_label = "django_fixer_app"

tests/test_ast_analyzer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Tests for MigrationAST helper methods in ast_analyzer.py."""
2+
23
from __future__ import annotations
34

45
import ast
@@ -13,6 +14,7 @@ def _ast(source: str) -> MigrationAST:
1314

1415
# ── module_var ────────────────────────────────────────────────────────
1516

17+
1618
def test_module_var_string():
1719
m = _ast("""
1820
revision = '001'
@@ -44,13 +46,15 @@ def downgrade(): pass
4446

4547
# ── parse error ───────────────────────────────────────────────────────
4648

49+
4750
def test_parse_error_sets_flag():
4851
m = MigrationAST("def upgrade(: pass\n", "001", "001.py")
4952
assert m._parse_error is not None
5053

5154

5255
# ── is_noop ───────────────────────────────────────────────────────────
5356

57+
5458
def test_is_noop_with_pass():
5559
m = _ast("""
5660
revision = '001'
@@ -82,6 +86,7 @@ def upgrade(): pass
8286

8387
# ── str_arg ───────────────────────────────────────────────────────────
8488

89+
8590
def test_str_arg_returns_string():
8691
tree = ast.parse("op.drop_table('users')")
8792
call = next(n for n in ast.walk(tree) if isinstance(n, ast.Call))
@@ -102,6 +107,7 @@ def test_str_arg_non_constant_returns_none():
102107

103108
# ── has_kwarg ─────────────────────────────────────────────────────────
104109

110+
105111
def test_has_kwarg_true():
106112
tree = ast.parse("op.alter_column('t', 'c', nullable=False)")
107113
call = next(n for n in ast.walk(tree) if isinstance(n, ast.Call))
@@ -116,6 +122,7 @@ def test_has_kwarg_false():
116122

117123
# ── kwarg_str ─────────────────────────────────────────────────────────
118124

125+
119126
def test_kwarg_str_returns_value():
120127
tree = ast.parse("op.create_index('ix', 'users', schema='public')")
121128
call = next(n for n in ast.walk(tree) if isinstance(n, ast.Call))
@@ -130,6 +137,7 @@ def test_kwarg_str_missing_returns_none():
130137

131138
# ── kwarg_bool ────────────────────────────────────────────────────────
132139

140+
133141
def test_kwarg_bool_true():
134142
tree = ast.parse("op.alter_column('t', 'c', nullable=True)")
135143
call = next(n for n in ast.walk(tree) if isinstance(n, ast.Call))
@@ -156,6 +164,7 @@ def test_kwarg_bool_missing_returns_none():
156164

157165
# ── sql_content ───────────────────────────────────────────────────────
158166

167+
159168
def test_sql_content_plain_string():
160169
tree = ast.parse("op.execute('SELECT 1')")
161170
call = next(n for n in ast.walk(tree) if isinstance(n, ast.Call))
@@ -177,20 +186,17 @@ def test_sql_content_no_args():
177186

178187
# ── find_column_calls ─────────────────────────────────────────────────
179188

189+
180190
def test_find_column_calls_sa_column():
181-
tree = ast.parse(
182-
"op.create_table('t', sa.Column('id', sa.Integer, primary_key=True))"
183-
)
191+
tree = ast.parse("op.create_table('t', sa.Column('id', sa.Integer, primary_key=True))")
184192
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
185193
outer = calls[0]
186194
cols = MigrationAST.find_column_calls(outer)
187195
assert len(cols) >= 1
188196

189197

190198
def test_find_column_calls_bare_column():
191-
tree = ast.parse(
192-
"op.create_table('t', Column('id', Integer, primary_key=True))"
193-
)
199+
tree = ast.parse("op.create_table('t', Column('id', Integer, primary_key=True))")
194200
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
195201
outer = calls[0]
196202
cols = MigrationAST.find_column_calls(outer)
@@ -199,6 +205,7 @@ def test_find_column_calls_bare_column():
199205

200206
# ── upgrade_methods / downgrade_methods ───────────────────────────────
201207

208+
202209
def test_upgrade_methods_returns_set():
203210
m = _ast("""
204211
revision = '001'
@@ -216,6 +223,7 @@ def downgrade():
216223

217224
# ── nested function not counted ───────────────────────────────────────
218225

226+
219227
def test_nested_function_calls_not_attributed_to_upgrade(tmp_path):
220228
"""Calls inside a nested def inside upgrade() should not appear in upgrade_calls."""
221229
m = _ast("""

0 commit comments

Comments
 (0)