Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 117 additions & 66 deletions tortoise/migrations/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,72 +53,17 @@ async def migrate(
direction: str = "both",
progress: Callable[[str, str, str], object] | None = None,
) -> None:
self._logger.debug("Building migration graph")
await self.loader.build_graph()

# Create a non-atomic schema editor for recorder operations only
recorder_schema_editor = self._schema_editor(atomic=False)
self._logger.debug("Ensuring migration schema")
await self.recorder.ensure_schema(recorder_schema_editor)

self._logger.debug("Loading applied migrations")
applied = set(await self.recorder.applied_migrations())

self._logger.debug("Building migration plan")
plan = self._migration_plan(targets, applied, self.loader.graph)
self._validate_plan_direction(plan, direction)

state_cache_by_key: dict[MigrationKey, State] | None = None
if any(step.backward for step in plan):
self._logger.debug("Building rollback state cache")
state_cache_by_key = await self._project_state_cache(applied)

state_cache: State | None = None
for step in plan:
key = MigrationKey(app_label=step.migration.app_label, name=step.migration.name)
if step.backward:
if state_cache_by_key is not None:
state_before = state_cache_by_key[key]
else:
state_before = await self._project_state(applied, upto=key)
if not fake:
self._emit(progress, "rollback_start", key)
schema_editor = self._schema_editor(atomic=step.migration.atomic)
if schema_editor.atomic_migration:
async with in_transaction(self.connection.connection_name) as txn_client:
schema_editor.client = txn_client
await step.migration.unapply(
state_before, dry_run=dry_run, schema_editor=schema_editor
)
else:
await step.migration.unapply(
state_before, dry_run=dry_run, schema_editor=schema_editor
)
self._emit(progress, "rollback_done", key)
if not dry_run:
await self.recorder.record_unapplied(key.app_label, key.name)
applied.discard(key)
state_cache = None
else:
if state_cache is None:
state_cache = await self._project_state(applied)
if not fake:
self._emit(progress, "apply_start", key)
schema_editor = self._schema_editor(atomic=step.migration.atomic)
if schema_editor.atomic_migration:
async with in_transaction(self.connection.connection_name) as txn_client:
schema_editor.client = txn_client
await step.migration.apply(
state_cache, dry_run=dry_run, schema_editor=schema_editor
)
else:
await step.migration.apply(
state_cache, dry_run=dry_run, schema_editor=schema_editor
)
self._emit(progress, "apply_done", key)
if not dry_run:
await self.recorder.record_applied(key.app_label, key.name)
applied.add(key)
plan, applied, state_cache_by_key = await self._prepare_migration_run(
targets, direction
)
await self._run_plan(
plan,
applied=applied,
state_cache_by_key=state_cache_by_key,
fake=fake,
dry_run=dry_run,
progress=progress,
)

async def plan(self, targets: Iterable[MigrationTarget] | None = None) -> list[PlanStep]:
await self.loader.build_graph()
Expand Down Expand Up @@ -193,6 +138,112 @@ async def collect_sql(

return editor.collected_sql

async def _execute_plan_step(
self,
step: PlanStep,
*,
applied: set[MigrationKey],
state_cache: State | None,
state_cache_by_key: dict[MigrationKey, State] | None,
fake: bool,
dry_run: bool,
progress: Callable[[str, str, str], object] | None,
) -> State | None:
key = MigrationKey(app_label=step.migration.app_label, name=step.migration.name)
if step.backward:
if state_cache_by_key is not None:
state_before = state_cache_by_key[key]
else:
state_before = await self._project_state(applied, upto=key)
if not fake:
self._emit(progress, "rollback_start", key)
schema_editor = self._schema_editor(atomic=step.migration.atomic)
if schema_editor.atomic_migration:
async with in_transaction(self.connection.connection_name) as txn_client:
schema_editor.client = txn_client
await step.migration.unapply(
state_before, dry_run=dry_run, schema_editor=schema_editor
)
else:
await step.migration.unapply(
state_before, dry_run=dry_run, schema_editor=schema_editor
)
self._emit(progress, "rollback_done", key)
if not dry_run:
await self.recorder.record_unapplied(key.app_label, key.name)
applied.discard(key)
return None

if state_cache is None:
state_cache = await self._project_state(applied)
if not fake:
self._emit(progress, "apply_start", key)
schema_editor = self._schema_editor(atomic=step.migration.atomic)
if schema_editor.atomic_migration:
async with in_transaction(self.connection.connection_name) as txn_client:
schema_editor.client = txn_client
await step.migration.apply(
state_cache, dry_run=dry_run, schema_editor=schema_editor
)
else:
await step.migration.apply(
state_cache, dry_run=dry_run, schema_editor=schema_editor
)
self._emit(progress, "apply_done", key)
if not dry_run:
await self.recorder.record_applied(key.app_label, key.name)
applied.add(key)
return state_cache

async def _prepare_migration_run(
self,
targets: Iterable[MigrationTarget] | None,
direction: str,
) -> tuple[list[PlanStep], set[MigrationKey], dict[MigrationKey, State] | None]:
self._logger.debug("Building migration graph")
await self.loader.build_graph()

# Create a non-atomic schema editor for recorder operations only
recorder_schema_editor = self._schema_editor(atomic=False)
self._logger.debug("Ensuring migration schema")
await self.recorder.ensure_schema(recorder_schema_editor)

self._logger.debug("Loading applied migrations")
applied = set(await self.recorder.applied_migrations())

self._logger.debug("Building migration plan")
plan = self._migration_plan(targets, applied, self.loader.graph)
self._validate_plan_direction(plan, direction)

state_cache_by_key: dict[MigrationKey, State] | None = None
if any(step.backward for step in plan):
self._logger.debug("Building rollback state cache")
state_cache_by_key = await self._project_state_cache(applied)

return plan, applied, state_cache_by_key

async def _run_plan(
self,
plan: list[PlanStep],
*,
applied: set[MigrationKey],
state_cache_by_key: dict[MigrationKey, State] | None,
fake: bool,
dry_run: bool,
progress: Callable[[str, str, str], object] | None,
) -> None:
state_cache: State | None = None
for step in plan:
state_cache = await self._execute_plan_step(
step,
applied=applied,
state_cache=state_cache,
state_cache_by_key=state_cache_by_key,
fake=fake,
dry_run=dry_run,
progress=progress,
)

def _resolve_migration_key(self, app_label: str, migration_name: str) -> MigrationKey:
"""Resolve a migration name to a MigrationKey, supporting prefix matching."""
exact_key = MigrationKey(app_label=app_label, name=migration_name)
Expand Down