Skip to content

Commit da5a132

Browse files
fix(server): serialize alembic migrations
1 parent 902dbd3 commit da5a132

2 files changed

Lines changed: 206 additions & 13 deletions

File tree

server/src/agent_control_server/migrate.py

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,36 @@
99

1010
import argparse
1111
import logging
12+
import os
1213
import shutil
1314
import sys
1415
import tempfile
16+
import time
1517
from collections.abc import Iterator
1618
from contextlib import contextmanager
1719
from pathlib import Path
1820
from typing import cast
1921

20-
from alembic import command
2122
from alembic.config import Config
23+
from sqlalchemy import create_engine, text
24+
from sqlalchemy.engine import Connection
25+
from sqlalchemy.engine.url import make_url
26+
from sqlalchemy.pool import NullPool
2227

2328
import agent_control_server
29+
from agent_control_server.config import db_config
30+
from alembic import command
31+
32+
# Module-level logger used for migration lock progress messages.
LOGGER = logging.getLogger(__name__)

# The advisory lock is identified by a pair of integer keys; both halves are
# ASCII tags packed into hex so the lock is recognizable when inspected.
_MIGRATION_LOCK_CLASS_ID = 0x4143544C  # "ACTL"
_MIGRATION_LOCK_OBJECT_ID = 0x4D494752  # "MIGR"

# Bind parameters shared by the lock and unlock statements so that both always
# address the same key pair.
_MIGRATION_LOCK_PARAMS = {
    "class_id": _MIGRATION_LOCK_CLASS_ID,
    "object_id": _MIGRATION_LOCK_OBJECT_ID,
}

# Pause between lock attempts while another migration holds the lock.
_MIGRATION_LOCK_POLL_SECONDS = 2.0

# Total time to keep retrying, and the environment variable that overrides it.
_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS = 600.0
_MIGRATION_LOCK_TIMEOUT_ENV = "AGENT_CONTROL_MIGRATION_LOCK_TIMEOUT_SECONDS"
2442

2543

2644
def _bundled_config() -> Config:
@@ -58,6 +76,88 @@ def _runtime_bundled_config() -> Iterator[Config]:
5876
yield cfg
5977

6078

79+
def _migration_url(cfg: Config) -> str:
    """Return the database URL to use for the migration lock connection.

    A non-empty ``sqlalchemy.url`` main option on the Alembic config takes
    precedence; when it is missing or empty the URL comes from
    ``db_config.get_url()``.
    """
    return cfg.get_main_option("sqlalchemy.url") or db_config.get_url()
84+
85+
86+
def _migration_lock_timeout_seconds() -> float:
    """Return how long to wait for the migration advisory lock, in seconds.

    The value is read from the environment variable named by
    ``_MIGRATION_LOCK_TIMEOUT_ENV``. When that variable is unset,
    ``_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS`` is returned.

    Raises:
        RuntimeError: If the variable is set but does not parse as a number
            (``nan`` included) or is not greater than zero.
    """
    import math

    raw_timeout = os.getenv(_MIGRATION_LOCK_TIMEOUT_ENV)
    if raw_timeout is None:
        return _DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS

    try:
        timeout = float(raw_timeout)
    except ValueError as exc:
        raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be a number.") from exc

    # float() accepts "nan", and NaN compares False against everything: it
    # would pass the positivity check below and yield a deadline that never
    # expires, turning the bounded wait into an unbounded one.
    if math.isnan(timeout):
        raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be a number.")

    if timeout <= 0:
        raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be greater than zero.")
    return timeout
99+
100+
101+
def _acquire_migration_lock(connection: Connection, timeout_seconds: float) -> None:
    """Poll for the migration advisory lock until it is obtained or time runs out.

    Each attempt runs ``pg_try_advisory_lock`` on *connection* with the shared
    lock key pair. Between failed attempts the function sleeps for
    ``_MIGRATION_LOCK_POLL_SECONDS``, or for the remaining time if that is
    shorter.

    Args:
        connection: Connection on which the lock is requested.
        timeout_seconds: Maximum total time to keep retrying.

    Raises:
        TimeoutError: If the lock is still unavailable once the deadline passes.
    """
    give_up_at = time.monotonic() + timeout_seconds
    announced_wait = False
    try_lock = text("SELECT pg_try_advisory_lock(:class_id, :object_id)")

    while not connection.execute(try_lock, _MIGRATION_LOCK_PARAMS).scalar_one():
        time_left = give_up_at - time.monotonic()
        if time_left <= 0:
            raise TimeoutError(
                f"Timed out after {timeout_seconds:g}s waiting for Agent Control "
                "migration advisory lock."
            )

        # Announce the wait once rather than on every poll.
        if not announced_wait:
            LOGGER.info("Waiting for another Agent Control migration to finish.")
            announced_wait = True

        # Never sleep past the deadline.
        time.sleep(min(_MIGRATION_LOCK_POLL_SECONDS, time_left))

    LOGGER.info("Acquired Agent Control migration advisory lock.")
127+
128+
129+
@contextmanager
def _serialized_migration(cfg: Config, *, enabled: bool) -> Iterator[None]:
    """Run the wrapped block while holding the migration advisory lock.

    The lock is only taken when *enabled* is true and the migration URL uses
    the ``postgresql`` backend; otherwise the block runs unguarded. A dedicated
    engine (``NullPool``) is created for the lock connection and disposed on
    exit.

    Args:
        cfg: Alembic config used to resolve the migration URL.
        enabled: Whether serialization is requested for this command.

    Raises:
        TimeoutError: If the lock cannot be obtained within the configured
            timeout.
        RuntimeError: If the timeout environment variable is invalid.
    """
    if not enabled:
        yield
        return

    url = _migration_url(cfg)
    if make_url(url).get_backend_name() != "postgresql":
        yield
        return

    # AUTOCOMMIT keeps the lock connection out of a transaction. Without it the
    # first execute() opens a transaction that would sit "idle in transaction"
    # for the whole migration; a server-side idle_in_transaction_session_timeout
    # could then end the session, dropping the session-level lock mid-migration
    # and making the unlock below fail.
    engine = create_engine(
        url,
        future=True,
        poolclass=NullPool,
        isolation_level="AUTOCOMMIT",
    )
    try:
        with engine.connect() as connection:
            _acquire_migration_lock(connection, _migration_lock_timeout_seconds())
            try:
                yield
            finally:
                # Release explicitly so the outcome can be logged.
                released = bool(
                    connection.execute(
                        text("SELECT pg_advisory_unlock(:class_id, :object_id)"),
                        _MIGRATION_LOCK_PARAMS,
                    ).scalar_one()
                )
                if released:
                    LOGGER.info("Released Agent Control migration advisory lock.")
                else:
                    LOGGER.warning(
                        "Agent Control migration advisory lock was not held at release."
                    )
    finally:
        engine.dispose()
159+
160+
61161
def _build_parser() -> argparse.ArgumentParser:
62162
parser = argparse.ArgumentParser(
63163
prog="agent-control-migrate",
@@ -104,18 +204,20 @@ def main(argv: list[str] | None = None) -> int:
104204

105205
try:
106206
with _runtime_bundled_config() as cfg:
107-
if parsed.command == "upgrade":
108-
command.upgrade(cfg, parsed.revision, sql=parsed.sql)
109-
elif parsed.command == "downgrade":
110-
command.downgrade(cfg, parsed.revision, sql=parsed.sql)
111-
elif parsed.command == "current":
112-
command.current(cfg)
113-
elif parsed.command == "history":
114-
command.history(cfg)
115-
elif parsed.command == "heads":
116-
command.heads(cfg)
117-
else: # pragma: no cover - argparse guarantees this cannot happen.
118-
parser.error("missing command")
207+
should_lock = parsed.command in {"upgrade", "downgrade"} and not parsed.sql
208+
with _serialized_migration(cfg, enabled=should_lock):
209+
if parsed.command == "upgrade":
210+
command.upgrade(cfg, parsed.revision, sql=parsed.sql)
211+
elif parsed.command == "downgrade":
212+
command.downgrade(cfg, parsed.revision, sql=parsed.sql)
213+
elif parsed.command == "current":
214+
command.current(cfg)
215+
elif parsed.command == "history":
216+
command.history(cfg)
217+
elif parsed.command == "heads":
218+
command.heads(cfg)
219+
else: # pragma: no cover - argparse guarantees this cannot happen.
220+
parser.error("missing command")
119221
except Exception as exc:
120222
print(f"agent-control-migrate: {exc}", file=sys.stderr)
121223
return 1

server/tests/test_migrate.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,53 @@
22

33
from pathlib import Path
44

5+
from alembic.config import Config
6+
57
import agent_control_server
68
from agent_control_server import migrate
79

810

11+
class _FakeResult:
    """Minimal query-result stand-in exposing a single scalar value."""

    def __init__(self, value: bool) -> None:
        # Value handed back verbatim by scalar_one().
        self.value = value

    def scalar_one(self) -> bool:
        """Return the value supplied at construction."""
        return self.value
17+
18+
19+
class _FakeConnection:
    """Connection stand-in that scripts advisory-lock outcomes.

    ``lock_results`` is consumed front to back: each ``pg_try_advisory_lock``
    statement pops the next value. Unlock statements always report success.
    The text of every executed statement is recorded in ``statements``.
    """

    def __init__(self, lock_results: list[bool]) -> None:
        self.lock_results = lock_results
        self.statements: list[str] = []

    def __enter__(self) -> _FakeConnection:
        return self

    def __exit__(self, *args: object) -> None:
        return None

    def execute(self, statement: object, params: object) -> _FakeResult:
        """Record *statement* and return the scripted result for it."""
        sql = str(statement)
        self.statements.append(sql)

        if "pg_try_advisory_lock" in sql:
            outcome = self.lock_results.pop(0)
        elif "pg_advisory_unlock" in sql:
            outcome = True
        else:
            # Any other SQL means the code under test did something unexpected.
            raise AssertionError(f"unexpected SQL statement: {sql}")
        return _FakeResult(outcome)
38+
39+
40+
class _FakeEngine:
    """Engine stand-in that hands out one fixed connection and tracks disposal."""

    def __init__(self, connection: _FakeConnection) -> None:
        self.connection = connection
        # Set to True once dispose() has been called.
        self.disposed = False

    def connect(self) -> _FakeConnection:
        """Return the connection supplied at construction."""
        return self.connection

    def dispose(self) -> None:
        """Mark the engine as disposed."""
        self.disposed = True
50+
51+
952
def test_bundled_config_omits_injected_version_init(
1053
tmp_path: Path,
1154
monkeypatch,
@@ -31,3 +74,51 @@ def test_bundled_config_omits_injected_version_init(
3174
assert not (script_location / "versions" / "__init__.py").exists()
3275

3376
assert not script_location.exists()
77+
78+
79+
def test_serialized_migration_skips_lock_for_non_postgres_url(monkeypatch) -> None:
    """A non-PostgreSQL URL must run the block without opening a lock connection."""

    def _refuse_engine(*args: object, **kwargs: object) -> object:
        raise AssertionError("non-postgres migrations should not create a lock connection")

    monkeypatch.setattr(migrate, "create_engine", _refuse_engine)

    sqlite_cfg = Config()
    sqlite_cfg.set_main_option("sqlalchemy.url", "sqlite:///agent-control.db")

    # Entering and leaving the context must not touch create_engine.
    with migrate._serialized_migration(sqlite_cfg, enabled=True):
        pass
90+
91+
92+
def test_serialized_migration_acquires_and_releases_postgres_lock(monkeypatch) -> None:
    """A PostgreSQL URL retries the lock, releases it on exit, and disposes the engine."""
    # The first lock attempt fails and the second succeeds, forcing one poll sleep.
    fake_connection = _FakeConnection([False, True])
    fake_engine = _FakeEngine(fake_connection)
    recorded_sleeps: list[float] = []

    monkeypatch.setattr(migrate, "create_engine", lambda *args, **kwargs: fake_engine)
    monkeypatch.setattr(migrate.time, "sleep", recorded_sleeps.append)

    postgres_cfg = Config()
    postgres_cfg.set_main_option(
        "sqlalchemy.url",
        "postgresql+psycopg://user:pass@postgres/db",
    )

    with migrate._serialized_migration(postgres_cfg, enabled=True):
        pass

    try_lock_sql = "SELECT pg_try_advisory_lock(:class_id, :object_id)"
    unlock_sql = "SELECT pg_advisory_unlock(:class_id, :object_id)"
    assert fake_connection.statements == [try_lock_sql, try_lock_sql, unlock_sql]
    assert recorded_sleeps == [2.0]
    assert fake_engine.disposed
112+
113+
114+
def test_serialized_migration_respects_disabled_lock(monkeypatch) -> None:
    """With ``enabled=False`` no lock connection is created, even for PostgreSQL."""

    def _refuse_engine(*args: object, **kwargs: object) -> object:
        raise AssertionError("disabled migration lock should not create a lock connection")

    monkeypatch.setattr(migrate, "create_engine", _refuse_engine)

    postgres_cfg = Config()
    postgres_cfg.set_main_option(
        "sqlalchemy.url",
        "postgresql+psycopg://user:pass@postgres/db",
    )

    with migrate._serialized_migration(postgres_cfg, enabled=False):
        pass

0 commit comments

Comments
 (0)