|
9 | 9 |
|
10 | 10 | import argparse |
11 | 11 | import logging |
| 12 | +import os |
12 | 13 | import shutil |
13 | 14 | import sys |
14 | 15 | import tempfile |
| 16 | +import time |
15 | 17 | from collections.abc import Iterator |
16 | 18 | from contextlib import contextmanager |
17 | 19 | from pathlib import Path |
18 | 20 | from typing import cast |
19 | 21 |
|
20 | | -from alembic import command |
21 | 22 | from alembic.config import Config |
| 23 | +from sqlalchemy import create_engine, text |
| 24 | +from sqlalchemy.engine import Connection |
| 25 | +from sqlalchemy.engine.url import make_url |
| 26 | +from sqlalchemy.pool import NullPool |
22 | 27 |
|
23 | 28 | import agent_control_server |
| 29 | +from agent_control_server.config import db_config |
| 30 | +from alembic import command |
| 31 | + |
LOGGER = logging.getLogger(__name__)
# Two-part key of the advisory lock used to serialize migrations. Each value is
# the ASCII encoding of the four-letter tag shown in its trailing comment.
_MIGRATION_LOCK_CLASS_ID = 0x4143544C  # "ACTL"
_MIGRATION_LOCK_OBJECT_ID = 0x4D494752  # "MIGR"
# Upper bound, in seconds, on the sleep between two lock attempts.
_MIGRATION_LOCK_POLL_SECONDS = 2.0
# Lock wait timeout, in seconds, used when the environment variable is unset.
_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS = 600.0
# Name of the environment variable that overrides the lock wait timeout.
_MIGRATION_LOCK_TIMEOUT_ENV = "AGENT_CONTROL_MIGRATION_LOCK_TIMEOUT_SECONDS"
# Bind parameters shared by the lock and unlock statements so both always
# address the same advisory lock.
_MIGRATION_LOCK_PARAMS = {
    "class_id": _MIGRATION_LOCK_CLASS_ID,
    "object_id": _MIGRATION_LOCK_OBJECT_ID,
}
24 | 42 |
|
25 | 43 |
|
26 | 44 | def _bundled_config() -> Config: |
@@ -58,6 +76,88 @@ def _runtime_bundled_config() -> Iterator[Config]: |
58 | 76 | yield cfg |
59 | 77 |
|
60 | 78 |
|
def _migration_url(cfg: Config) -> str:
    """Return the database URL the migration should run against.

    The ``sqlalchemy.url`` main option of ``cfg`` is used when it is set to a
    non-empty value; otherwise the URL comes from ``db_config.get_url()``.
    """
    return cfg.get_main_option("sqlalchemy.url") or db_config.get_url()
| 84 | + |
| 85 | + |
def _migration_lock_timeout_seconds() -> float:
    """Return how long, in seconds, to wait for the migration advisory lock.

    The value is read from the environment variable named by
    ``_MIGRATION_LOCK_TIMEOUT_ENV``; when that variable is unset the default
    ``_DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS`` is returned.

    Returns:
        A timeout strictly greater than zero.

    Raises:
        RuntimeError: If the variable is set but cannot be parsed as a float,
            or if the parsed value is not greater than zero (NaN included).
    """
    raw_timeout = os.getenv(_MIGRATION_LOCK_TIMEOUT_ENV)
    if raw_timeout is None:
        return _DEFAULT_MIGRATION_LOCK_TIMEOUT_SECONDS

    try:
        timeout = float(raw_timeout)
    except ValueError as exc:
        raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be a number.") from exc

    # ``not timeout > 0`` instead of ``timeout <= 0``: float() accepts "nan",
    # and NaN fails every ordering comparison. A NaN timeout would slip past a
    # ``<= 0`` check and produce a deadline that never expires.
    if not timeout > 0:
        raise RuntimeError(f"{_MIGRATION_LOCK_TIMEOUT_ENV} must be greater than zero.")
    return timeout
| 99 | + |
| 100 | + |
def _acquire_migration_lock(connection: Connection, timeout_seconds: float) -> None:
    """Poll for the migration advisory lock until it is held or time runs out.

    Each attempt runs ``pg_try_advisory_lock`` on ``connection``. Between
    failed attempts the function sleeps for at most
    ``_MIGRATION_LOCK_POLL_SECONDS``, never past the deadline. A single
    "waiting" message is logged the first time the lock is found busy.

    Args:
        connection: Connection on which the lock is requested.
        timeout_seconds: Maximum time to keep retrying, in seconds.

    Raises:
        TimeoutError: If the lock is still unavailable once the deadline passes.
    """
    give_up_at = time.monotonic() + timeout_seconds
    try_lock = text("SELECT pg_try_advisory_lock(:class_id, :object_id)")
    announced_wait = False

    while not bool(connection.execute(try_lock, _MIGRATION_LOCK_PARAMS).scalar_one()):
        time_left = give_up_at - time.monotonic()
        if time_left <= 0:
            raise TimeoutError(
                f"Timed out after {timeout_seconds:g}s waiting for Agent Control "
                "migration advisory lock."
            )

        # Only announce the wait once, however many polls it takes.
        if not announced_wait:
            LOGGER.info("Waiting for another Agent Control migration to finish.")
            announced_wait = True

        # Cap the nap at the remaining budget so the deadline is honoured.
        time.sleep(min(_MIGRATION_LOCK_POLL_SECONDS, time_left))

    LOGGER.info("Acquired Agent Control migration advisory lock.")
| 127 | + |
| 128 | + |
@contextmanager
def _serialized_migration(cfg: Config, *, enabled: bool) -> Iterator[None]:
    """Hold the migration advisory lock while the wrapped block runs.

    Yields immediately, without locking, when ``enabled`` is false or when the
    migration URL's backend is not ``postgresql``. Otherwise a dedicated,
    non-pooled connection acquires the advisory lock before the block runs and
    releases it afterwards, even if the block raises.

    Args:
        cfg: Alembic config used to resolve the database URL.
        enabled: Whether this invocation should be serialized at all.

    Raises:
        TimeoutError: If the lock cannot be acquired within the configured
            timeout.
        RuntimeError: If the lock-timeout environment variable is invalid.
    """
    if not enabled:
        yield
        return

    url = _migration_url(cfg)
    if make_url(url).get_backend_name() != "postgresql":
        yield
        return

    # AUTOCOMMIT keeps the lock connection out of an open transaction. Without
    # it the first statement begins a transaction that stays open for the whole
    # migration, leaving the session "idle in transaction"; a server-side
    # idle_in_transaction_session_timeout could then terminate the session and
    # silently drop the session-level advisory lock mid-migration.
    engine = create_engine(
        url,
        future=True,
        poolclass=NullPool,
        isolation_level="AUTOCOMMIT",
    )
    try:
        with engine.connect() as connection:
            _acquire_migration_lock(connection, _migration_lock_timeout_seconds())
            try:
                yield
            finally:
                # Release explicitly on the same connection that took the lock.
                released = bool(
                    connection.execute(
                        text("SELECT pg_advisory_unlock(:class_id, :object_id)"),
                        _MIGRATION_LOCK_PARAMS,
                    ).scalar_one()
                )
                if released:
                    LOGGER.info("Released Agent Control migration advisory lock.")
                else:
                    LOGGER.warning("Agent Control migration advisory lock was not held at release.")
    finally:
        engine.dispose()
| 159 | + |
| 160 | + |
61 | 161 | def _build_parser() -> argparse.ArgumentParser: |
62 | 162 | parser = argparse.ArgumentParser( |
63 | 163 | prog="agent-control-migrate", |
@@ -104,18 +204,20 @@ def main(argv: list[str] | None = None) -> int: |
104 | 204 |
|
105 | 205 | try: |
106 | 206 | with _runtime_bundled_config() as cfg: |
107 | | - if parsed.command == "upgrade": |
108 | | - command.upgrade(cfg, parsed.revision, sql=parsed.sql) |
109 | | - elif parsed.command == "downgrade": |
110 | | - command.downgrade(cfg, parsed.revision, sql=parsed.sql) |
111 | | - elif parsed.command == "current": |
112 | | - command.current(cfg) |
113 | | - elif parsed.command == "history": |
114 | | - command.history(cfg) |
115 | | - elif parsed.command == "heads": |
116 | | - command.heads(cfg) |
117 | | - else: # pragma: no cover - argparse guarantees this cannot happen. |
118 | | - parser.error("missing command") |
| 207 | + should_lock = parsed.command in {"upgrade", "downgrade"} and not parsed.sql |
| 208 | + with _serialized_migration(cfg, enabled=should_lock): |
| 209 | + if parsed.command == "upgrade": |
| 210 | + command.upgrade(cfg, parsed.revision, sql=parsed.sql) |
| 211 | + elif parsed.command == "downgrade": |
| 212 | + command.downgrade(cfg, parsed.revision, sql=parsed.sql) |
| 213 | + elif parsed.command == "current": |
| 214 | + command.current(cfg) |
| 215 | + elif parsed.command == "history": |
| 216 | + command.history(cfg) |
| 217 | + elif parsed.command == "heads": |
| 218 | + command.heads(cfg) |
| 219 | + else: # pragma: no cover - argparse guarantees this cannot happen. |
| 220 | + parser.error("missing command") |
119 | 221 | except Exception as exc: |
120 | 222 | print(f"agent-control-migrate: {exc}", file=sys.stderr) |
121 | 223 | return 1 |
|
0 commit comments