Skip to content

Commit 72df6bc

Browse files
authored
feat: built-in default tests + schema drift detection (v1.1.0) (#27)
* feat: add built-in default tests and schema drift detection (v1.1.0) - Auto-collect 5 built-in tests when MRTConfig is registered (no imports needed). Disable with mrt_default_tests = "false" in pytest.ini. - Add MRTFixture.assert_schema_matches() — fails on SQLAlchemy model/migration drift. Django mode delegates to manage.py makemigrations --check. - Add mrt drift CLI command for schema drift reporting. - Add MRTConfig.target_metadata field for import-path based metadata resolution. - Fix get_versions_dir() to return actual versions/ path, not script root. * fix: resolve 6 code-review bugs (drift safety, bare excepts, DB state) * style: apply ruff format * fix: narrow Item | Collector to Item in pytest_collection_modifyitems
1 parent 488e5e1 commit 72df6bc

9 files changed

Lines changed: 326 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ Versioning: [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
77

88
---
99

10+
## [1.1.0] — 2026-06-07
11+
12+
### Added
13+
- **Built-in default tests**: Five tests are now auto-collected when `MRTConfig` is registered in `conftest.py` — no imports required. Tests: `test_mrt_single_head`, `test_mrt_upgrade`, `test_mrt_downgrade_base`, `test_mrt_static_no_errors`, `test_mrt_schema_matches_models`. Disable with `mrt_default_tests = "false"` in `pytest.ini` / `pyproject.toml`.
14+
- **Schema drift detection**: `MRTFixture.assert_schema_matches()` — fails if the DB schema after running all migrations does not match the SQLAlchemy model definitions. Accepts a `MetaData` instance or an import-path string (`"myapp.models:Base"`). Django mode delegates to `manage.py makemigrations --check`.
15+
- **`mrt drift` CLI command**: `mrt drift myapp.models:Base --config alembic.ini --db-url sqlite:///test.db` — runs migrations to head, compares schema against models, prints a diff table, exits 1 on drift.
16+
- **`MRTConfig.target_metadata`**: New field (`str | None`) — import path for the SQLAlchemy `Base` or `MetaData` used by `assert_schema_matches()` and `test_mrt_schema_matches_models`.
17+
18+
---
19+
1020
## [1.0.1] — 2026-06-06
1121

1222
### Fixed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "pytest-mrt"
7-
version = "1.0.1"
7+
version = "1.1.0"
88
description = "Catch database migration rollback failures before they reach production"
99
readme = "README.md"
1010
license = { file = "LICENSE" }

pytest_mrt/cli.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,77 @@ def check(
126126
raise typer.Exit(1)
127127

128128

129+
# ──────────────────────────────────────────────
130+
# mrt drift
131+
# ──────────────────────────────────────────────
132+
133+
134+
@app.command("drift")
135+
def drift(
136+
metadata: str = typer.Argument(
137+
help="SQLAlchemy metadata import path, e.g. 'myapp.models:Base'"
138+
),
139+
alembic_ini: str = typer.Option("alembic.ini", "--config", "-c", help="Path to alembic.ini"),
140+
db_url: str = typer.Option("", "--db-url", help="Database URL (overrides alembic.ini)"),
141+
) -> None:
142+
"""Check if SQLAlchemy model definitions match the current migration state.
143+
144+
Compares the live DB schema (after running all migrations) against the
145+
SQLAlchemy models you point at. Exits 0 if clean, 1 if drift is found.
146+
147+
Example:
148+
149+
mrt drift myapp.models:Base --config alembic.ini --db-url sqlite:///test.db
150+
"""
151+
from .core.drift import compare_schema, describe_diff, load_metadata
152+
from .core.runner import MigrationRunner
153+
154+
# Load metadata
155+
try:
156+
target_metadata = load_metadata(metadata)
157+
except (ValueError, ImportError, AttributeError) as exc:
158+
console.print(f"[red]Error loading metadata:[/red] {exc}")
159+
raise typer.Exit(1)
160+
161+
# Build runner (needs alembic.ini + db_url)
162+
if not Path(alembic_ini).exists():
163+
console.print(f"[red]alembic.ini not found:[/red] {alembic_ini}")
164+
raise typer.Exit(1)
165+
166+
try:
167+
runner = MigrationRunner(alembic_ini, db_url)
168+
except Exception as exc:
169+
console.print(f"[red]Failed to connect:[/red] {exc}")
170+
raise typer.Exit(1)
171+
172+
console.print("[dim]Upgrading to head...[/dim]")
173+
try:
174+
runner.upgrade("head")
175+
except Exception as exc:
176+
console.print(f"[red]Migration failed:[/red] {exc}")
177+
raise typer.Exit(1)
178+
179+
console.print("[dim]Comparing schema...[/dim]")
180+
diffs = compare_schema(runner.engine, target_metadata)
181+
182+
if not diffs:
183+
console.print("[green]✓ No schema drift — models match migrations.[/green]")
184+
raise typer.Exit(0)
185+
186+
table = Table(box=box.ROUNDED, title="Schema Drift", show_lines=True)
187+
table.add_column("#", style="dim", no_wrap=True)
188+
table.add_column("Difference")
189+
190+
for i, d in enumerate(diffs, 1):
191+
table.add_row(str(i), describe_diff(d))
192+
193+
console.print(table)
194+
console.print(
195+
f"\n[red]{len(diffs)} difference(s) found.[/red] Run migrations or update your models."
196+
)
197+
raise typer.Exit(1)
198+
199+
129200
# ──────────────────────────────────────────────
130201
# mrt init
131202
# ──────────────────────────────────────────────

pytest_mrt/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,9 @@ class MRTConfig:
7676
# Model used by `mrt explain`. Defaults to DEFAULT_EXPLAIN_MODEL.
7777
# Override to use a different Claude model, e.g. "claude-3-5-haiku-latest".
7878
explain_model: str = DEFAULT_EXPLAIN_MODEL
79+
80+
# Import path for the SQLAlchemy declarative Base (or MetaData) used by
81+
# assert_schema_matches() and the built-in test_mrt_schema_matches_models test.
82+
# Format: "myapp.models:Base" or "myapp.models:Base.metadata"
83+
# Example: target_metadata="myproject.db.models:Base"
84+
target_metadata: str | None = None

pytest_mrt/core/drift.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
4+
def load_metadata(metadata_path: str):
5+
"""Import SQLAlchemy metadata from a dotted path like 'myapp.models:Base'.
6+
7+
The format is 'module.path:AttributeName'. Dotted attribute traversal after
8+
the colon is not supported — use 'myapp.models:Base' (not 'myapp.models:Base.metadata').
9+
Both a declarative Base class and a MetaData instance are accepted.
10+
"""
11+
if ":" not in metadata_path:
12+
raise ValueError(
13+
f"Invalid metadata path '{metadata_path}'. Use the form 'myapp.models:Base'."
14+
)
15+
module_path, attr = metadata_path.rsplit(":", 1)
16+
if "." in attr:
17+
raise ValueError(
18+
f"Dotted attribute '{attr}' is not supported after the colon. "
19+
f"Use 'myapp.models:Base' instead of 'myapp.models:Base.metadata'."
20+
)
21+
import importlib
22+
23+
mod = importlib.import_module(module_path)
24+
obj = getattr(mod, attr)
25+
# Accept either a declarative Base class or a MetaData instance directly.
26+
return getattr(obj, "metadata", obj)
27+
28+
29+
def compare_schema(engine, target_metadata) -> list:
30+
"""Return alembic autogenerate diffs between DB schema and target_metadata."""
31+
from alembic.autogenerate import compare_metadata
32+
from alembic.runtime.migration import MigrationContext
33+
34+
with engine.connect() as conn:
35+
ctx = MigrationContext.configure(conn)
36+
return compare_metadata(ctx, target_metadata)
37+
38+
39+
def describe_diff(diff) -> str:
40+
"""Format a single alembic autogenerate diff tuple into a human-readable string."""
41+
if not isinstance(diff, tuple):
42+
return repr(diff)
43+
kind = diff[0]
44+
if kind == "add_table":
45+
return f"add table '{diff[1].name}' (in models, missing from DB)"
46+
if kind == "remove_table":
47+
return f"remove table '{diff[1].name}' (in DB, missing from models)"
48+
if kind == "add_column":
49+
return f"add column '{diff[2]}.{diff[3].name}' (in models, missing from DB)"
50+
if kind == "remove_column":
51+
return f"remove column '{diff[2]}.{diff[3].name}' (in DB, missing from models)"
52+
if kind == "modify_type":
53+
return f"column '{diff[2]}.{diff[3]}' type mismatch: models={diff[5]!r}, DB={diff[4]!r}"
54+
if kind == "modify_nullable":
55+
return f"column '{diff[2]}.{diff[3]}' nullable mismatch: models={diff[5]}, DB={diff[4]}"
56+
return repr(diff)

pytest_mrt/core/runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,8 @@ def get_revisions(self) -> list:
7272

7373
def get_versions_dir(self) -> str:
7474
script = ScriptDirectory.from_config(self.alembic_cfg)
75-
return script.dir
75+
# script.versions returns the single version_locations path (where .py migration files live).
76+
# script.dir is the script_location root (contains env.py, script.py.mako, etc.) — not what we want.
77+
# script.versions raises CommandError when multiple version_locations are configured; surface
78+
# that instead of silently falling back to script.dir (which contains non-migration files).
79+
return script.versions

pytest_mrt/default_tests.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
Built-in default tests for pytest-mrt.
3+
4+
These are automatically collected when MRTConfig is registered in conftest.py.
5+
To disable auto-collection, add the following to pyproject.toml or pytest.ini:
6+
7+
[tool.pytest.ini_options]
8+
mrt_default_tests = "false"
9+
10+
You can also import individual tests explicitly:
11+
12+
from pytest_mrt.default_tests import test_mrt_upgrade, test_mrt_downgrade_base
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import pytest
18+
19+
20+
def test_mrt_single_head(mrt) -> None:
21+
"""Exactly one head revision exists in the migration chain."""
22+
if mrt._django_mode:
23+
pytest.skip("single-head check not applicable to Django mode")
24+
from alembic.script import ScriptDirectory
25+
26+
script = ScriptDirectory.from_config(mrt._runner.alembic_cfg)
27+
heads = script.get_heads()
28+
assert len(heads) == 1, (
29+
f"Expected a single head revision, found {len(heads)}: {heads}.\n"
30+
"Run `alembic merge heads` to create a merge migration."
31+
)
32+
33+
34+
def test_mrt_upgrade(mrt) -> None:
35+
"""Migration chain upgrades to head without error."""
36+
if mrt._django_mode:
37+
pytest.skip("use test_mrt_all_reversible for Django mode")
38+
mrt.upgrade("head")
39+
40+
41+
def test_mrt_downgrade_base(mrt) -> None:
42+
"""Migration chain downgrades to base without error."""
43+
if mrt._django_mode:
44+
pytest.skip("use test_mrt_all_reversible for Django mode")
45+
mrt.upgrade("head")
46+
mrt._runner.downgrade_base()
47+
try:
48+
mrt.upgrade("head") # restore to head so subsequent tests start clean
49+
except Exception as exc:
50+
pytest.fail(
51+
f"Migration chain failed to re-upgrade to head after downgrade_base: {exc}\n"
52+
"DB state is now at 'base'. Fix the upgrade() failure before continuing."
53+
)
54+
55+
56+
def test_mrt_static_no_errors(mrt) -> None:
57+
"""No static analysis errors found in migration files."""
58+
mrt.assert_no_static_errors()
59+
60+
61+
def test_mrt_schema_matches_models(mrt) -> None:
62+
"""Model definitions match the migration state (no schema drift).
63+
64+
Requires ``MRTConfig(target_metadata='myapp.models:Base')`` to be set.
65+
Skipped automatically when target_metadata is not configured.
66+
"""
67+
if mrt._config.target_metadata is None:
68+
pytest.skip(
69+
"target_metadata not configured — skipping schema drift check.\n"
70+
"Set MRTConfig(target_metadata='myapp.models:Base') to enable."
71+
)
72+
# Ensure migrations are applied before comparing schema.
73+
# This makes the test safe to run in isolation (e.g. pytest -k test_mrt_schema_matches_models).
74+
if not mrt._django_mode:
75+
mrt.upgrade("head")
76+
mrt.assert_schema_matches()

pytest_mrt/plugin.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .config import MRTConfig
88
from .core.detector import RiskWarning, analyze_migrations
9+
from .core.drift import compare_schema, describe_diff, load_metadata
910
from .core.runner import MigrationRunner
1011
from .core.schema import SchemaSnapshot
1112
from .core.seeder import SmartSeeder
@@ -226,14 +227,112 @@ def assert_all_reversible(self, apps: list[str] | None = None) -> None:
226227
lines.append(r.failure_summary())
227228
pytest.fail("Some migrations are not safely reversible:\n" + "\n".join(lines))
228229

230+
# ── schema drift ──────────────────────────────────────────────────
231+
232+
def assert_schema_matches(
233+
self,
234+
target_metadata=None,
235+
metadata_path: str | None = None,
236+
) -> None:
237+
"""Fail if the DB schema does not match the SQLAlchemy model definitions.
238+
239+
For Django mode, delegates to ``manage.py makemigrations --check``.
240+
241+
Args:
242+
target_metadata: A SQLAlchemy ``MetaData`` instance (or declarative
243+
``Base``) to compare against. When omitted, falls back to
244+
``MRTConfig.target_metadata`` (an import-path string).
245+
metadata_path: Import path override, e.g. ``"myapp.models:Base"``.
246+
Takes precedence over ``MRTConfig.target_metadata``.
247+
"""
248+
if self._django_mode:
249+
self._assert_django_no_drift()
250+
return
251+
252+
if target_metadata is None:
253+
path = metadata_path or self._config.target_metadata
254+
if path is None:
255+
raise ValueError(
256+
"assert_schema_matches() requires either a target_metadata argument "
257+
"or MRTConfig(target_metadata='myapp.models:Base')."
258+
)
259+
target_metadata = load_metadata(path)
260+
261+
diffs = compare_schema(self._runner.engine, target_metadata)
262+
if diffs:
263+
lines = [f" {describe_diff(d)}" for d in diffs]
264+
pytest.fail(
265+
f"Schema drift detected ({len(diffs)} difference(s)):\n" + "\n".join(lines)
266+
)
267+
268+
def _assert_django_no_drift(self) -> None:
269+
from io import StringIO
270+
271+
from django.core.management import call_command
272+
273+
out = StringIO()
274+
try:
275+
call_command("makemigrations", "--check", "--dry-run", stdout=out, stderr=out)
276+
except SystemExit as exc:
277+
if exc.code != 0:
278+
pytest.fail(
279+
"Schema drift: model changes detected that don't have migrations.\n"
280+
"Run `python manage.py makemigrations` to generate them."
281+
)
282+
except Exception as exc:
283+
pytest.fail(
284+
f"assert_schema_matches() failed while checking Django migrations: {exc}\n"
285+
"Check that DJANGO_SETTINGS_MODULE is set correctly and all models can be imported."
286+
)
287+
229288
def reset(self) -> None:
230289
self._seeder.reset()
231290

232291

292+
def pytest_addoption(parser: pytest.Parser) -> None:
293+
parser.addini(
294+
"mrt_default_tests",
295+
help="Set to 'false' to disable auto-collected built-in MRT tests.",
296+
default="true",
297+
)
298+
299+
233300
def pytest_configure(config: pytest.Config) -> None:
234301
config.addinivalue_line("markers", "mrt: migration rollback test")
235302

236303

304+
def pytest_collection_modifyitems(
305+
session: pytest.Session,
306+
config: pytest.Config,
307+
items: list[pytest.Item],
308+
) -> None:
309+
"""Prepend built-in default tests when MRTConfig is registered."""
310+
if getattr(config, "_mrt_config", None) is None:
311+
return
312+
if config.getini("mrt_default_tests") == "false":
313+
return
314+
315+
from pathlib import Path
316+
317+
try:
318+
from _pytest.python import Module as _PytestModule
319+
320+
import pytest_mrt.default_tests as _dt
321+
322+
dt_path = Path(_dt.__file__)
323+
module = _PytestModule.from_parent(session, path=dt_path)
324+
new_items: list[pytest.Item] = [i for i in module.collect() if isinstance(i, pytest.Item)]
325+
items[:0] = new_items
326+
except Exception as _exc:
327+
import warnings
328+
329+
warnings.warn(
330+
f"pytest-mrt: failed to inject built-in default tests: {_exc}\n"
331+
"Set mrt_default_tests = 'false' in pytest.ini to suppress this warning.",
332+
stacklevel=2,
333+
)
334+
335+
237336
@pytest.fixture
238337
def mrt(request: pytest.FixtureRequest) -> Iterator[MRTFixture]:
239338
cfg: MRTConfig = getattr(request.config, "_mrt_config", None) or MRTConfig()

tests/test_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,8 @@ def test_fixture_type(mrt):
523523
def test_fixture_has_config(mrt):
524524
assert mrt._config is not None
525525
""")
526+
# Disable default test injection so this test focuses on fixture wiring only.
527+
pytester.makeini("[pytest]\nmrt_default_tests = false\n")
526528
result = pytester.runpytest("-v")
527529
result.assert_outcomes(passed=2)
528530

0 commit comments

Comments
 (0)