Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "squawk-alembic"
version = "0.3.1"
version = "0.3.2"
description = "Pre-commit hook to lint Alembic migration SQL with squawk"
packages = [{include = "squawk_alembic"}]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion squawk_alembic/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.1"
__version__ = "0.3.2"
16 changes: 10 additions & 6 deletions squawk_alembic/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
import sys
import tempfile
from ast import Assign, Constant, Name, Tuple, iter_child_nodes, parse
from ast import AnnAssign, Assign, Constant, Name, Tuple, iter_child_nodes, parse
from configparser import ConfigParser, NoOptionError, NoSectionError
from pathlib import Path

Expand Down Expand Up @@ -57,12 +57,16 @@ def extract_revision_info(filepath):
down_revision = None

for node in iter_child_nodes(tree):
if not isinstance(node, Assign):
if isinstance(node, AnnAssign):
if not isinstance(node.target, Name) or node.value is None:
continue
name = node.target.id
elif isinstance(node, Assign):
if len(node.targets) != 1 or not isinstance(node.targets[0], Name):
continue
name = node.targets[0].id
else:
continue
if len(node.targets) != 1 or not isinstance(node.targets[0], Name):
continue

name = node.targets[0].id
if name == "revision":
if isinstance(node.value, Constant) and isinstance(node.value.value, str):
revision = node.value.value
Expand Down
57 changes: 57 additions & 0 deletions tests/test_revision_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,60 @@ def upgrade():
pass
""")
assert extract_revision_info(path) is None


def test_annotated_assignment(migration_file):
path = migration_file("""
from typing import Sequence, Union

revision: str = 'abc123'
down_revision: Union[str, None] = 'def456'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade():
pass
""")
info = extract_revision_info(path)
assert info is not None
assert info.revision == "abc123"
assert info.down_revision == "def456"
assert info.is_merge is False


def test_annotated_first_migration_down_revision_is_none(migration_file):
path = migration_file("""
from typing import Sequence, Union

revision: str = 'abc123'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade():
pass
""")
info = extract_revision_info(path)
assert info is not None
assert info.revision == "abc123"
assert info.down_revision is None
assert info.is_merge is False


def test_annotated_merge_migration(migration_file):
path = migration_file("""
from typing import Sequence, Union

revision: str = 'merge001'
down_revision: Union[str, None] = ('abc123', 'def456')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade():
pass
""")
info = extract_revision_info(path)
assert info is not None
assert info.revision == "merge001"
assert info.down_revision == ("abc123", "def456")
assert info.is_merge is True