Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions tests/migrations/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def test_writer_format_create_model_basic(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise import fields, migrations
from tortoise.migrations import operations as ops
from tortoise import fields

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -107,9 +106,8 @@ def test_writer_format_rename_and_alter(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise import fields, migrations
from tortoise.migrations import operations as ops
from tortoise import fields

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -156,10 +154,9 @@ def test_writer_format_options_indexes_constraints(tmp_path: Path, monkeypatch)
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tortoise import fields
from tortoise import fields, migrations
from tortoise.indexes import Index, PartialIndex
from tortoise.migrations import operations as ops
from tortoise.migrations.constraints import UniqueConstraint

class Migration(migrations.Migration):
Expand Down Expand Up @@ -233,10 +230,9 @@ def test_writer_renders_fk_field(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tortoise import fields, migrations
from tortoise.fields.base import OnDelete
from tortoise import fields
from tortoise.migrations import operations as ops

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -272,10 +268,9 @@ def test_writer_excludes_fk_source_field(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tortoise import fields, migrations
from tortoise.fields.base import OnDelete
from tortoise import fields
from tortoise.migrations import operations as ops

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -312,10 +307,9 @@ def test_writer_serializes_on_delete_enum(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tortoise import fields, migrations
from tortoise.fields.base import OnDelete
from tortoise import fields
from tortoise.migrations import operations as ops

class Migration(migrations.Migration):
operations = [
Expand All @@ -341,9 +335,8 @@ def test_writer_skips_missing_db_index(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise import fields, migrations
from tortoise.migrations import operations as ops
from tortoise import fields

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -464,10 +457,9 @@ def test_writer_handles_one_to_one_field(tmp_path: Path, monkeypatch) -> None:
]
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tortoise import fields, migrations
from tortoise.fields.base import OnDelete
from tortoise import fields
from tortoise.migrations import operations as ops

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -500,10 +492,9 @@ def test_writer_handles_enum_fields(tmp_path: Path, monkeypatch) -> None:
# NOT fields.IntEnumFieldInstance or fields.CharEnumFieldInstance
expected = textwrap.dedent(
"""\
from tortoise import migrations
from tortoise.migrations import operations as ops
from tests.migrations.test_writer import Role, Status
from tortoise import fields
from tortoise import fields, migrations
from tortoise.migrations import operations as ops

class Migration(migrations.Migration):
operations = [
Expand Down Expand Up @@ -533,9 +524,9 @@ def test_writer_format_runpython(tmp_path: Path, monkeypatch) -> None:
operations = [RunPython(_runpython_forward, reverse_code=_runpython_reverse, atomic=False)]
expected = textwrap.dedent(
"""\
from tests.migrations.test_writer import _runpython_forward, _runpython_reverse
from tortoise import migrations
from tortoise.migrations import operations as ops
from tests.migrations.test_writer import _runpython_forward, _runpython_reverse

class Migration(migrations.Migration):
operations = [
Expand Down
21 changes: 13 additions & 8 deletions tortoise/migrations/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,28 @@ def add_index_class(self, name: str) -> None:
def add_constraint_class(self, name: str) -> None:
self.uses_constraints.add(name)

def render(self) -> list[str]:
lines: list[str] = []
def render(self, *required_imports: str) -> list[str]:
lines: list[str] = [*required_imports]
for module in sorted(self.modules):
lines.append(f"import {module}")
for module, names in sorted(self.imports.items()):
lines.append(f"from {module} import {', '.join(sorted(names))}")
if self.uses_fields_module:
lines.append("from tortoise import fields")
from_import = "from tortoise import "
for idx, line in enumerate(lines):
if line.startswith(from_import):
imported = [i.strip() for i in line[len(from_import) :].split(",")]
lines[idx] = from_import + ", ".join(sorted(["fields", *imported]))
break
else:
lines.append("from tortoise import fields")
if self.uses_indexes:
index_names = ", ".join(sorted(self.uses_indexes))
lines.append(f"from tortoise.indexes import {index_names}")
if self.uses_constraints:
constraint_names = ", ".join(sorted(self.uses_constraints))
lines.append(f"from tortoise.migrations.constraints import {constraint_names}")
return lines
return sorted(lines)


def _resolve_import(value: Any) -> tuple[str, str, bool]:
Expand Down Expand Up @@ -288,13 +295,11 @@ def as_string(self) -> str:
for operation in self.operations:
operations.extend(self._format_operation(operation, imports, indent=" " * 8))

lines: list[str] = [
required_imports = [
"from tortoise import migrations",
"from tortoise.migrations import operations as ops",
]
extra_imports = imports.render()
if extra_imports:
lines.extend(extra_imports)
lines: list[str] = imports.render(*required_imports)
lines.extend(["", "class Migration(migrations.Migration):"])
blocks: list[list[str]] = []
if self.dependencies:
Expand Down
Loading