|
6 | 6 |
|
7 | 7 | from .config import MRTConfig |
8 | 8 | from .core.detector import RiskWarning, analyze_migrations |
| 9 | +from .core.drift import compare_schema, describe_diff, load_metadata |
9 | 10 | from .core.runner import MigrationRunner |
10 | 11 | from .core.schema import SchemaSnapshot |
11 | 12 | from .core.seeder import SmartSeeder |
@@ -226,14 +227,112 @@ def assert_all_reversible(self, apps: list[str] | None = None) -> None: |
226 | 227 | lines.append(r.failure_summary()) |
227 | 228 | pytest.fail("Some migrations are not safely reversible:\n" + "\n".join(lines)) |
228 | 229 |
|
| 230 | + # ── schema drift ────────────────────────────────────────────────── |
| 231 | + |
| 232 | + def assert_schema_matches( |
| 233 | + self, |
| 234 | + target_metadata=None, |
| 235 | + metadata_path: str | None = None, |
| 236 | + ) -> None: |
| 237 | + """Fail if the DB schema does not match the SQLAlchemy model definitions. |
| 238 | +
|
| 239 | + For Django mode, delegates to ``manage.py makemigrations --check``. |
| 240 | +
|
| 241 | + Args: |
| 242 | + target_metadata: A SQLAlchemy ``MetaData`` instance (or declarative |
| 243 | + ``Base``) to compare against. When omitted, falls back to |
| 244 | + ``MRTConfig.target_metadata`` (an import-path string). |
| 245 | + metadata_path: Import path override, e.g. ``"myapp.models:Base"``. |
| 246 | + Takes precedence over ``MRTConfig.target_metadata``. |
| 247 | + """ |
| 248 | + if self._django_mode: |
| 249 | + self._assert_django_no_drift() |
| 250 | + return |
| 251 | + |
| 252 | + if target_metadata is None: |
| 253 | + path = metadata_path or self._config.target_metadata |
| 254 | + if path is None: |
| 255 | + raise ValueError( |
| 256 | + "assert_schema_matches() requires either a target_metadata argument " |
| 257 | + "or MRTConfig(target_metadata='myapp.models:Base')." |
| 258 | + ) |
| 259 | + target_metadata = load_metadata(path) |
| 260 | + |
| 261 | + diffs = compare_schema(self._runner.engine, target_metadata) |
| 262 | + if diffs: |
| 263 | + lines = [f" {describe_diff(d)}" for d in diffs] |
| 264 | + pytest.fail( |
| 265 | + f"Schema drift detected ({len(diffs)} difference(s)):\n" + "\n".join(lines) |
| 266 | + ) |
| 267 | + |
| 268 | + def _assert_django_no_drift(self) -> None: |
| 269 | + from io import StringIO |
| 270 | + |
| 271 | + from django.core.management import call_command |
| 272 | + |
| 273 | + out = StringIO() |
| 274 | + try: |
| 275 | + call_command("makemigrations", "--check", "--dry-run", stdout=out, stderr=out) |
| 276 | + except SystemExit as exc: |
| 277 | + if exc.code != 0: |
| 278 | + pytest.fail( |
| 279 | + "Schema drift: model changes detected that don't have migrations.\n" |
| 280 | + "Run `python manage.py makemigrations` to generate them." |
| 281 | + ) |
| 282 | + except Exception as exc: |
| 283 | + pytest.fail( |
| 284 | + f"assert_schema_matches() failed while checking Django migrations: {exc}\n" |
| 285 | + "Check that DJANGO_SETTINGS_MODULE is set correctly and all models can be imported." |
| 286 | + ) |
| 287 | + |
229 | 288 | def reset(self) -> None: |
230 | 289 | self._seeder.reset() |
231 | 290 |
|
232 | 291 |
|
| 292 | +def pytest_addoption(parser: pytest.Parser) -> None: |
| 293 | + parser.addini( |
| 294 | + "mrt_default_tests", |
| 295 | + help="Set to 'false' to disable auto-collected built-in MRT tests.", |
| 296 | + default="true", |
| 297 | + ) |
| 298 | + |
| 299 | + |
233 | 300 | def pytest_configure(config: pytest.Config) -> None: |
234 | 301 | config.addinivalue_line("markers", "mrt: migration rollback test") |
235 | 302 |
|
236 | 303 |
|
| 304 | +def pytest_collection_modifyitems( |
| 305 | + session: pytest.Session, |
| 306 | + config: pytest.Config, |
| 307 | + items: list[pytest.Item], |
| 308 | +) -> None: |
| 309 | + """Prepend built-in default tests when MRTConfig is registered.""" |
| 310 | + if getattr(config, "_mrt_config", None) is None: |
| 311 | + return |
| 312 | + if config.getini("mrt_default_tests") == "false": |
| 313 | + return |
| 314 | + |
| 315 | + from pathlib import Path |
| 316 | + |
| 317 | + try: |
| 318 | + from _pytest.python import Module as _PytestModule |
| 319 | + |
| 320 | + import pytest_mrt.default_tests as _dt |
| 321 | + |
| 322 | + dt_path = Path(_dt.__file__) |
| 323 | + module = _PytestModule.from_parent(session, path=dt_path) |
| 324 | + new_items: list[pytest.Item] = [i for i in module.collect() if isinstance(i, pytest.Item)] |
| 325 | + items[:0] = new_items |
| 326 | + except Exception as _exc: |
| 327 | + import warnings |
| 328 | + |
| 329 | + warnings.warn( |
| 330 | + f"pytest-mrt: failed to inject built-in default tests: {_exc}\n" |
| 331 | + "Set mrt_default_tests = 'false' in pytest.ini to suppress this warning.", |
| 332 | + stacklevel=2, |
| 333 | + ) |
| 334 | + |
| 335 | + |
237 | 336 | @pytest.fixture |
238 | 337 | def mrt(request: pytest.FixtureRequest) -> Iterator[MRTFixture]: |
239 | 338 | cfg: MRTConfig = getattr(request.config, "_mrt_config", None) or MRTConfig() |
|
0 commit comments