
Commit 124cb60

feat: switch from AST parsing to alembic upgrade --sql
Replace extract_sql/AST extraction with generate_sql that calls alembic upgrade --sql for complete DDL generation. This gives squawk visibility into ORM operations like op.create_index and op.create_table that were previously invisible. Extend autocommit checker to flag op.create_index/op.drop_index with postgresql_concurrently=True outside autocommit_block.
1 parent b1d1a1e commit 124cb60

9 files changed

Lines changed: 946 additions & 503 deletions

File tree

README.md

Lines changed: 10 additions & 8 deletions
@@ -2,7 +2,9 @@

 A [pre-commit](https://pre-commit.com/) hook that lints SQL in [Alembic](https://alembic.sqlalchemy.org/) migrations using [squawk](https://squawkhq.com/), a PostgreSQL migration linter.

-Squawk operates on raw SQL files, but Alembic migrations are Python. This hook bridges the gap by extracting SQL from `op.execute()` calls in migration files and passing it to squawk for analysis.
+Squawk operates on raw SQL files, but Alembic migrations are Python. This hook bridges the gap by generating DDL via `alembic upgrade --sql` (offline mode) and passing the complete SQL output to squawk for analysis. This captures all SQL statements a migration produces, including ORM operations like `op.create_index()`, `op.create_table()`, and `op.alter_column()`.
+
+The hook also checks that concurrent index operations (`CONCURRENTLY` in `op.execute()` or `postgresql_concurrently=True` in `op.create_index()` / `op.drop_index()`) are wrapped in `autocommit_block()`.
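To illustrate the autocommit check described above, here is a minimal, self-contained sketch of how an AST visitor could flag `postgresql_concurrently=True` outside `autocommit_block()`. It is a simplified stand-in rather than the hook's actual checker: the names `ConcurrentIndexChecker`, `find_unsafe_concurrent_ops`, and the sample migration source are hypothetical, and only the `op.create_index()` / `op.drop_index()` case is covered.

```python
import ast

# Hypothetical migration source: the first index is created outside
# autocommit_block(), the second one inside it.
SAMPLE_MIGRATION = '''
from alembic import op


def upgrade():
    op.create_index("ix_a", "t", ["a"], postgresql_concurrently=True)
    with op.get_context().autocommit_block():
        op.create_index("ix_b", "t", ["b"], postgresql_concurrently=True)
'''


def _is_autocommit_block(expr):
    """True for a call whose final attribute is `autocommit_block`."""
    return (
        isinstance(expr, ast.Call)
        and isinstance(expr.func, ast.Attribute)
        and expr.func.attr == "autocommit_block"
    )


class ConcurrentIndexChecker(ast.NodeVisitor):
    """Collect line numbers of concurrent index ops outside autocommit_block()."""

    def __init__(self):
        self.flagged_lines = []
        self._autocommit_depth = 0

    def visit_With(self, node):
        enters = any(_is_autocommit_block(item.context_expr) for item in node.items)
        if enters:
            self._autocommit_depth += 1
        self.generic_visit(node)
        if enters:
            self._autocommit_depth -= 1

    def visit_Call(self, node):
        func = node.func
        is_index_op = (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "op"
            and func.attr in ("create_index", "drop_index")
        )
        is_concurrent = any(
            kw.arg == "postgresql_concurrently"
            and isinstance(kw.value, ast.Constant)
            and kw.value.value is True
            for kw in node.keywords
        )
        if is_index_op and is_concurrent and self._autocommit_depth == 0:
            self.flagged_lines.append(node.lineno)
        self.generic_visit(node)


def find_unsafe_concurrent_ops(source):
    """Return line numbers of concurrent index ops not wrapped in autocommit_block()."""
    checker = ConcurrentIndexChecker()
    checker.visit(ast.parse(source))
    return checker.flagged_lines
```

Calling `find_unsafe_concurrent_ops(SAMPLE_MIGRATION)` reports only the first `op.create_index()` call, since the second sits inside the `with` block.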

 ## Usage

@@ -11,25 +13,24 @@ Add the following to your `.pre-commit-config.yaml`:
 ```yaml
 repos:
   - repo: https://github.com/kintsugi-tax/kintsugi-squawk
-    rev: v0.1.0
+    rev: v0.2.0
     hooks:
       - id: squawk-alembic
 ```

-No additional configuration is required. The hook auto-detects your migrations directory by reading `script_location` from `alembic.ini`.
+No additional configuration is required. The hook auto-detects your migrations directory by reading `script_location` from `alembic.ini`. The consumer's `alembic` must be available on PATH (the hook calls it via subprocess).
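Because the hook shells out to `alembic`, a quick preflight check can confirm that the executable resolves on PATH before the hook runs. The sketch below uses the standard library's `shutil.which`; the function name `executable_on_path` is hypothetical and not part of the hook.

```python
import shutil


def executable_on_path(name="alembic"):
    """Return True if `name` resolves to an executable on the current PATH."""
    return shutil.which(name) is not None


if __name__ == "__main__":
    if not executable_on_path():
        print("alembic not found on PATH; install it in the active environment")
```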

 ## How It Works

 When pre-commit runs, the hook:

 1. Parses `alembic.ini` to find the migrations `versions/` directory
 2. Filters staged files to only those under that directory
-3. Extracts SQL strings from `op.execute()` calls using Python's AST parser
-4. Pipes the extracted SQL to squawk for linting
-
-The hook handles common patterns including `op.execute("...")`, `op.execute(sa.text("..."))`, triple-quoted strings, and implicit string concatenation.
+3. Checks for concurrent operations outside `autocommit_block()`
+4. Runs `alembic upgrade --sql` to generate the complete DDL for each migration
+5. Pipes the generated SQL to squawk for linting

-ORM-level operations like `op.add_column()` and `op.create_table()` are not linted, since they don't contain raw SQL. These produce safe, predictable DDL that squawk is less likely to flag.
+Merge migrations (where `down_revision` is a tuple) are skipped since they produce no DDL.
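As a rough sketch of step 4 and the merge-migration rule, the revision range passed to `alembic upgrade --sql` can be derived from a migration's `revision` and `down_revision`. The helper name `upgrade_range` is hypothetical; falling back to `base` for a first migration and skipping tuple-valued `down_revision` follow the `generate_sql` function in this commit.

```python
def upgrade_range(revision, down_revision):
    """Build the `start:end` range for `alembic upgrade <range> --sql`.

    Returns None for merge migrations (tuple `down_revision`), which are skipped.
    """
    if isinstance(down_revision, tuple):
        return None
    # A migration with no parent starts from Alembic's symbolic "base" revision.
    start = down_revision or "base"
    return f"{start}:{revision}"


# The resulting range would be used roughly like this:
#   alembic upgrade a1:b2 --sql
print(upgrade_range("b2", "a1"))
print(upgrade_range("a1", None))
print(upgrade_range("m3", ("b2", "c2")))
```

Restricting the range to a single step keeps the generated SQL limited to the one migration being linted instead of the whole history.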

 ## Squawk Configuration

@@ -41,6 +42,7 @@ Squawk reads its configuration from `.squawk.toml` in the consumer repo root. Se

 * Python (version 3.12)
 * Poetry
+* squawk-cli (`pip install squawk-cli`)

 **Steps:**

poetry.lock

Lines changed: 307 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ python = ">=3.10"
 squawk-cli = "*"

 [tool.poetry.group.dev.dependencies]
+alembic = "*"
 pre-commit = "*"
 pytest = "*"
 ruff = "*"

squawk_alembic/hook.py

Lines changed: 121 additions & 41 deletions
@@ -1,7 +1,8 @@
-"""Pre-commit hook that extracts SQL from Alembic migrations and lints with squawk."""
+"""Pre-commit hook that generates DDL via alembic upgrade --sql and lints with squawk."""

 import ast
 import configparser
+import os
 import subprocess
 import sys
 import tempfile
@@ -31,50 +32,98 @@ def find_migrations_path():
     return None


-def extract_sql(filepath):
-    """Parse a migration file and extract SQL strings from op.execute() calls."""
+class RevisionInfo:
+    __slots__ = ("revision", "down_revision", "is_merge")
+
+    def __init__(self, revision, down_revision, is_merge):
+        self.revision = revision
+        self.down_revision = down_revision
+        self.is_merge = is_merge
+
+
+def extract_revision_info(filepath):
+    """Parse a migration file to extract revision and down_revision from module-level assignments."""
     with open(filepath) as f:
         try:
             tree = ast.parse(f.read())
         except SyntaxError:
-            return []
+            return None

-    statements = []
+    revision = None
+    down_revision = None

-    for node in ast.walk(tree):
-        if not isinstance(node, ast.Call):
+    for node in ast.iter_child_nodes(tree):
+        if not isinstance(node, ast.Assign):
             continue
-
-        if not (
-            isinstance(node.func, ast.Attribute)
-            and node.func.attr == "execute"
-            and isinstance(node.func.value, ast.Name)
-            and node.func.value.id == "op"
-            and node.args
-        ):
+        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
             continue

-        sql = _extract_string(node.args[0])
-        if sql:
-            statements.append(sql)
+        name = node.targets[0].id
+        if name == "revision":
+            if isinstance(node.value, ast.Constant) and isinstance(
+                node.value.value, str
+            ):
+                revision = node.value.value
+        elif name == "down_revision":
+            if isinstance(node.value, ast.Constant):
+                if isinstance(node.value.value, str):
+                    down_revision = node.value.value
+                elif node.value.value is None:
+                    down_revision = None
+            elif isinstance(node.value, ast.Tuple):
+                values = []
+                for elt in node.value.elts:
+                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                        values.append(elt.value)
+                down_revision = tuple(values)
+
+    if revision is None:
+        return None

-    return statements
+    is_merge = isinstance(down_revision, tuple)
+    return RevisionInfo(
+        revision=revision, down_revision=down_revision, is_merge=is_merge
+    )
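One case `extract_revision_info` does not cover: a migration that declares `revision: str = "..."` with a type annotation parses as `ast.AnnAssign` rather than `ast.Assign`, so the function would return `None` and the file would be skipped without linting. Newer Alembic script templates appear to emit annotated assignments of this kind, though that depends on the template in use. The sketch below is not part of this commit; `read_module_constant` is a hypothetical helper showing one way to read a module-level literal from either node type.

```python
import ast


def read_module_constant(source, name):
    """Return the literal assigned to `name` at module level, or None if absent."""
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.Assign) and len(node.targets) == 1:
            target, value = node.targets[0], node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # Annotated form: `revision: str = "abc123"`
            target, value = node.target, node.value
        else:
            continue
        if isinstance(target, ast.Name) and target.id == name:
            try:
                # literal_eval accepts an AST node and handles str, None, and tuples.
                return ast.literal_eval(value)
            except ValueError:
                return None
    return None


annotated = 'revision: str = "b2"\ndown_revision: Union[str, None] = "a1"\n'
plain = 'revision = "m3"\ndown_revision = ("b2", "c2")\n'
print(read_module_constant(annotated, "revision"))
print(read_module_constant(plain, "down_revision"))
```

Note that a missing name and an explicit `down_revision = None` both yield `None` here, which matches how the committed code treats them.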


-def _extract_string(node):
-    """Extract a string value from an AST node, handling common wrappers."""
-    # Direct string literal: op.execute("SQL")
-    if isinstance(node, ast.Constant) and isinstance(node.value, str):
-        return node.value
+def generate_sql(filepath):
+    """Run alembic upgrade --sql to generate the complete DDL for a migration."""
+    info = extract_revision_info(filepath)
+    if info is None:
+        return None

-    # sa.text("SQL") or text("SQL")
-    if isinstance(node, ast.Call) and node.args:
-        if isinstance(node.func, ast.Attribute) and node.func.attr == "text":
-            return _extract_string(node.args[0])
-        if isinstance(node.func, ast.Name) and node.func.id == "text":
-            return _extract_string(node.args[0])
+    if info.is_merge:
+        return None

-    return None
+    base = info.down_revision if info.down_revision else "base"
+    target = f"{base}:{info.revision}"
+
+    env = os.environ.copy()
+    if "DATABASE_URL" not in env:
+        env["DATABASE_URL"] = "postgresql://localhost/lint"
+
+    try:
+        result = subprocess.run(
+            ["alembic", "upgrade", target, "--sql"],
+            capture_output=True,
+            text=True,
+            env=env,
+        )
+    except FileNotFoundError:
+        print(
+            "squawk-alembic: alembic not found. Ensure alembic is installed in your environment.",
+            file=sys.stderr,
+        )
+        return None
+
+    if result.returncode != 0:
+        print(
+            f"squawk-alembic: alembic upgrade --sql failed for {filepath}:\n{result.stderr}",
+            file=sys.stderr,
+        )
+        return None
+
+    return result.stdout


 def check_autocommit_blocks(filepath):
@@ -90,6 +139,18 @@ def check_autocommit_blocks(filepath):
     return checker.warnings


+def _has_concurrent_kwarg(node):
+    """Check if an AST Call node has postgresql_concurrently=True."""
+    for kw in node.keywords:
+        if (
+            kw.arg == "postgresql_concurrently"
+            and isinstance(kw.value, ast.Constant)
+            and kw.value.value is True
+        ):
+            return True
+    return False
+
+
 class _AutocommitChecker(ast.NodeVisitor):
     def __init__(self):
         self.warnings = []
@@ -113,17 +174,38 @@ def visit_With(self, node):
     def visit_Call(self, node):
         if (
             isinstance(node.func, ast.Attribute)
-            and node.func.attr == "execute"
             and isinstance(node.func.value, ast.Name)
             and node.func.value.id == "op"
-            and node.args
         ):
-            sql = _extract_string(node.args[0])
-            if sql and "concurrently" in sql.lower() and not self._in_autocommit:
-                self.warnings.append(node.lineno)
+            # op.execute("...CONCURRENTLY...")
+            if node.func.attr == "execute" and node.args:
+                sql = _extract_string(node.args[0])
+                if sql and "concurrently" in sql.lower() and not self._in_autocommit:
+                    self.warnings.append(node.lineno)
+
+            # op.create_index(..., postgresql_concurrently=True)
+            # op.drop_index(..., postgresql_concurrently=True)
+            if node.func.attr in ("create_index", "drop_index"):
+                if _has_concurrent_kwarg(node) and not self._in_autocommit:
+                    self.warnings.append(node.lineno)
+
         self.generic_visit(node)


+def _extract_string(node):
+    """Extract a string value from an AST node, handling common wrappers."""
+    if isinstance(node, ast.Constant) and isinstance(node.value, str):
+        return node.value
+
+    if isinstance(node, ast.Call) and node.args:
+        if isinstance(node.func, ast.Attribute) and node.func.attr == "text":
+            return _extract_string(node.args[0])
+        if isinstance(node.func, ast.Name) and node.func.id == "text":
+            return _extract_string(node.args[0])
+
+    return None
+
+
 def main():
     files = sys.argv[1:]
     if not files:
@@ -152,14 +234,12 @@ def main():
             )
             exit_code = 1

-        statements = extract_sql(filepath)
-        if not statements:
+        sql = generate_sql(filepath)
+        if not sql:
             continue

-        combined = "\n".join(statements)
-
         with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tmp:
-            tmp.write(combined)
+            tmp.write(sql)
             tmp_path = tmp.name

         try:
