diff --git a/tortoise/migrations/executor.py b/tortoise/migrations/executor.py index 77ffa5f17..0d228f65b 100644 --- a/tortoise/migrations/executor.py +++ b/tortoise/migrations/executor.py @@ -53,72 +53,17 @@ async def migrate( direction: str = "both", progress: Callable[[str, str, str], object] | None = None, ) -> None: - self._logger.debug("Building migration graph") - await self.loader.build_graph() - - # Create a non-atomic schema editor for recorder operations only - recorder_schema_editor = self._schema_editor(atomic=False) - self._logger.debug("Ensuring migration schema") - await self.recorder.ensure_schema(recorder_schema_editor) - - self._logger.debug("Loading applied migrations") - applied = set(await self.recorder.applied_migrations()) - - self._logger.debug("Building migration plan") - plan = self._migration_plan(targets, applied, self.loader.graph) - self._validate_plan_direction(plan, direction) - - state_cache_by_key: dict[MigrationKey, State] | None = None - if any(step.backward for step in plan): - self._logger.debug("Building rollback state cache") - state_cache_by_key = await self._project_state_cache(applied) - - state_cache: State | None = None - for step in plan: - key = MigrationKey(app_label=step.migration.app_label, name=step.migration.name) - if step.backward: - if state_cache_by_key is not None: - state_before = state_cache_by_key[key] - else: - state_before = await self._project_state(applied, upto=key) - if not fake: - self._emit(progress, "rollback_start", key) - schema_editor = self._schema_editor(atomic=step.migration.atomic) - if schema_editor.atomic_migration: - async with in_transaction(self.connection.connection_name) as txn_client: - schema_editor.client = txn_client - await step.migration.unapply( - state_before, dry_run=dry_run, schema_editor=schema_editor - ) - else: - await step.migration.unapply( - state_before, dry_run=dry_run, schema_editor=schema_editor - ) - self._emit(progress, "rollback_done", key) - if not dry_run: - await self.recorder.record_unapplied(key.app_label, key.name) - applied.discard(key) - state_cache = None - else: - if state_cache is None: - state_cache = await self._project_state(applied) - if not fake: - self._emit(progress, "apply_start", key) - schema_editor = self._schema_editor(atomic=step.migration.atomic) - if schema_editor.atomic_migration: - async with in_transaction(self.connection.connection_name) as txn_client: - schema_editor.client = txn_client - await step.migration.apply( - state_cache, dry_run=dry_run, schema_editor=schema_editor - ) - else: - await step.migration.apply( - state_cache, dry_run=dry_run, schema_editor=schema_editor - ) - self._emit(progress, "apply_done", key) - if not dry_run: - await self.recorder.record_applied(key.app_label, key.name) - applied.add(key) + plan, applied, state_cache_by_key = await self._prepare_migration_run( + targets, direction + ) + await self._run_plan( + plan, + applied=applied, + state_cache_by_key=state_cache_by_key, + fake=fake, + dry_run=dry_run, + progress=progress, + ) async def plan(self, targets: Iterable[MigrationTarget] | None = None) -> list[PlanStep]: await self.loader.build_graph() @@ -193,6 +138,112 @@ async def collect_sql( return editor.collected_sql + async def _execute_plan_step( + self, + step: PlanStep, + *, + applied: set[MigrationKey], + state_cache: State | None, + state_cache_by_key: dict[MigrationKey, State] | None, + fake: bool, + dry_run: bool, + progress: Callable[[str, str, str], object] | None, + ) -> State | None: + key = MigrationKey(app_label=step.migration.app_label, name=step.migration.name) + if step.backward: + if state_cache_by_key is not None: + state_before = state_cache_by_key[key] + else: + state_before = await self._project_state(applied, upto=key) + if not fake: + self._emit(progress, "rollback_start", key) + schema_editor = self._schema_editor(atomic=step.migration.atomic) + if schema_editor.atomic_migration: + async with in_transaction(self.connection.connection_name) as txn_client: + schema_editor.client = txn_client + await step.migration.unapply( + state_before, dry_run=dry_run, schema_editor=schema_editor + ) + else: + await step.migration.unapply( + state_before, dry_run=dry_run, schema_editor=schema_editor + ) + self._emit(progress, "rollback_done", key) + if not dry_run: + await self.recorder.record_unapplied(key.app_label, key.name) + applied.discard(key) + return None + + if state_cache is None: + state_cache = await self._project_state(applied) + if not fake: + self._emit(progress, "apply_start", key) + schema_editor = self._schema_editor(atomic=step.migration.atomic) + if schema_editor.atomic_migration: + async with in_transaction(self.connection.connection_name) as txn_client: + schema_editor.client = txn_client + await step.migration.apply( + state_cache, dry_run=dry_run, schema_editor=schema_editor + ) + else: + await step.migration.apply( + state_cache, dry_run=dry_run, schema_editor=schema_editor + ) + self._emit(progress, "apply_done", key) + if not dry_run: + await self.recorder.record_applied(key.app_label, key.name) + applied.add(key) + return state_cache + + async def _prepare_migration_run( + self, + targets: Iterable[MigrationTarget] | None, + direction: str, + ) -> tuple[list[PlanStep], set[MigrationKey], dict[MigrationKey, State] | None]: + self._logger.debug("Building migration graph") + await self.loader.build_graph() + + # Create a non-atomic schema editor for recorder operations only + recorder_schema_editor = self._schema_editor(atomic=False) + self._logger.debug("Ensuring migration schema") + await self.recorder.ensure_schema(recorder_schema_editor) + + self._logger.debug("Loading applied migrations") + applied = set(await self.recorder.applied_migrations()) + + self._logger.debug("Building migration plan") + plan = self._migration_plan(targets, applied, self.loader.graph) + self._validate_plan_direction(plan, direction) + + state_cache_by_key: dict[MigrationKey, State] | None = None + if any(step.backward for step in plan): + self._logger.debug("Building rollback state cache") + state_cache_by_key = await self._project_state_cache(applied) + + return plan, applied, state_cache_by_key + + async def _run_plan( + self, + plan: list[PlanStep], + *, + applied: set[MigrationKey], + state_cache_by_key: dict[MigrationKey, State] | None, + fake: bool, + dry_run: bool, + progress: Callable[[str, str, str], object] | None, + ) -> None: + state_cache: State | None = None + for step in plan: + state_cache = await self._execute_plan_step( + step, + applied=applied, + state_cache=state_cache, + state_cache_by_key=state_cache_by_key, + fake=fake, + dry_run=dry_run, + progress=progress, + ) + def _resolve_migration_key(self, app_label: str, migration_name: str) -> MigrationKey: """Resolve a migration name to a MigrationKey, supporting prefix matching.""" exact_key = MigrationKey(app_label=app_label, name=migration_name)