Skip to content

Commit c998ece

Browse files
fix: handle annotated assignments (AnnAssign) in migration files
Newer Alembic versions generate `revision: str = '...'` instead of `revision = '...'`. The AST parser only handled Assign nodes, silently skipping every migration that used the annotated form. Closes #4
1 parent 2ff2940 commit c998ece

2 files changed

Lines changed: 67 additions & 6 deletions

File tree

squawk_alembic/hook.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
import sys
88
import tempfile
9-
from ast import Assign, Constant, Name, Tuple, iter_child_nodes, parse
9+
from ast import AnnAssign, Assign, Constant, Name, Tuple, iter_child_nodes, parse
1010
from configparser import ConfigParser, NoOptionError, NoSectionError
1111
from pathlib import Path
1212

@@ -57,12 +57,16 @@ def extract_revision_info(filepath):
5757
down_revision = None
5858

5959
for node in iter_child_nodes(tree):
60-
if not isinstance(node, Assign):
60+
if isinstance(node, AnnAssign):
61+
if not isinstance(node.target, Name) or node.value is None:
62+
continue
63+
name = node.target.id
64+
elif isinstance(node, Assign):
65+
if len(node.targets) != 1 or not isinstance(node.targets[0], Name):
66+
continue
67+
name = node.targets[0].id
68+
else:
6169
continue
62-
if len(node.targets) != 1 or not isinstance(node.targets[0], Name):
63-
continue
64-
65-
name = node.targets[0].id
6670
if name == "revision":
6771
if isinstance(node.value, Constant) and isinstance(node.value.value, str):
6872
revision = node.value.value

tests/test_revision_info.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,60 @@ def upgrade():
8585
pass
8686
""")
8787
assert extract_revision_info(path) is None
88+
89+
90+
def test_annotated_assignment(migration_file):
91+
path = migration_file("""
92+
from typing import Sequence, Union
93+
94+
revision: str = 'abc123'
95+
down_revision: Union[str, None] = 'def456'
96+
branch_labels: Union[str, Sequence[str], None] = None
97+
depends_on: Union[str, Sequence[str], None] = None
98+
99+
def upgrade():
100+
pass
101+
""")
102+
info = extract_revision_info(path)
103+
assert info is not None
104+
assert info.revision == "abc123"
105+
assert info.down_revision == "def456"
106+
assert info.is_merge is False
107+
108+
109+
def test_annotated_first_migration_down_revision_is_none(migration_file):
110+
path = migration_file("""
111+
from typing import Sequence, Union
112+
113+
revision: str = 'abc123'
114+
down_revision: Union[str, None] = None
115+
branch_labels: Union[str, Sequence[str], None] = None
116+
depends_on: Union[str, Sequence[str], None] = None
117+
118+
def upgrade():
119+
pass
120+
""")
121+
info = extract_revision_info(path)
122+
assert info is not None
123+
assert info.revision == "abc123"
124+
assert info.down_revision is None
125+
assert info.is_merge is False
126+
127+
128+
def test_annotated_merge_migration(migration_file):
129+
path = migration_file("""
130+
from typing import Sequence, Union
131+
132+
revision: str = 'merge001'
133+
down_revision: Union[str, None] = ('abc123', 'def456')
134+
branch_labels: Union[str, Sequence[str], None] = None
135+
depends_on: Union[str, Sequence[str], None] = None
136+
137+
def upgrade():
138+
pass
139+
""")
140+
info = extract_revision_info(path)
141+
assert info is not None
142+
assert info.revision == "merge001"
143+
assert info.down_revision == ("abc123", "def456")
144+
assert info.is_merge is True

0 commit comments

Comments
 (0)