Skip to content

Commit 0f0130a

Browse files
authored
FEAT: Implement DB schema tracking with alembic (#1631)
1 parent 6dab9f2 commit 0f0130a

21 files changed

Lines changed: 1321 additions & 91 deletions

.pre-commit-config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ repos:
3030
files: ^(doc/.*\.(py|ipynb|md)|doc/myst\.yml)$
3131
pass_filenames: false
3232
additional_dependencies: ['pyyaml']
33+
- id: enforce_alembic_revision_immutability
34+
name: Enforce Alembic Revision Immutability
35+
entry: python ./build_scripts/enforce_alembic_revision_immutability.py
36+
language: python
37+
files: ^pyrit/memory/alembic/versions/.*\.py$
38+
pass_filenames: false
39+
- id: memory-migrations-check
40+
name: Check Memory Migrations
41+
entry: python ./build_scripts/memory_migrations.py check
42+
language: system
43+
pass_filenames: false
44+
files: ^pyrit/memory/(memory_models\.py|alembic/.*|migration\.py)$
3345

3446
- repo: https://github.com/pre-commit/pre-commit-hooks
3547
rev: v5.0.0

.pyrit_conf_example

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ operation: op_trash_panda
9595
# - /path/to/.env
9696
# - /path/to/.env.local
9797

98+
# Schema Migration Check
99+
# ---------------------
100+
# If true, runs database schema migration on startup to ensure the database
101+
# is up to date with the latest PyRIT version.
102+
# Set to false to skip the check (e.g., for read-only access, testing, or
103+
# when managing migrations externally).
104+
check_schema: true
105+
98106
# Silent Mode
99107
# -----------
100108
# If true, suppresses print statements during initialization.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Migration history must be immutable. This hook enforces that by preventing deletion of, or updates to, migration scripts.
6+
7+
Checks both staged changes (local pre-commit) and the full branch diff against origin/main (CI).
8+
"""
9+
10+
import subprocess
11+
import sys
12+
13+
_VERSIONS_PATH = "pyrit/memory/alembic/versions/"
14+
15+
16+
def _git(*args: str) -> str:
    """
    Run ``git`` with the given arguments and return its stripped stdout.

    The exit status is not inspected, so a failing command does not raise here;
    the caller just receives whatever was written to stdout.
    """
    command = ["git"]
    command.extend(args)
    completed = subprocess.run(command, capture_output=True, text=True)
    return completed.stdout.strip()
19+
20+
21+
def _has_non_add_changes(diff_spec: list[str]) -> bool:
    """
    Report whether a git diff touches existing files under the versions path.

    Args:
        diff_spec: Extra arguments selecting what ``git diff`` compares
            (e.g. ``["--cached"]`` or a commit range).

    Returns:
        bool: True if any non-empty ``--name-status`` line does not start with
        ``A`` (i.e. is something other than a plain addition).
    """
    status_report = _git("diff", "--name-status", *diff_spec, "--", _VERSIONS_PATH)
    for entry in status_report.splitlines():
        if not entry:
            continue
        if not entry.startswith("A"):
            return True
    return False
24+
25+
26+
def has_revision_violations() -> bool:
    """
    Check staged changes and the branch diff for edits to existing revisions.

    Returns:
        bool: True if either the staged changes or the diff from the merge base
        with ``origin/main`` up to ``HEAD`` contains a non-add change under the
        versions path.
    """
    # Local pre-commit: check staged changes
    staged_violation = _has_non_add_changes(["--cached"])
    if staged_violation:
        return True

    # CI: check full branch diff against origin/main
    base_commit = _git("merge-base", "origin/main", "HEAD")
    if not base_commit:
        # No merge base was reported, so there is no branch range to inspect.
        return False
    return bool(_has_non_add_changes([f"{base_commit}...HEAD"]))
34+
35+
36+
# Script entry point: report the violation and exit non-zero so the hook fails.
if __name__ == "__main__" and has_revision_violations():
    print("[ERROR] Migration scripts can only be added, not modified or deleted.")
    sys.exit(1)

build_scripts/memory_migrations.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import argparse
5+
import sys
6+
import tempfile
7+
from pathlib import Path
8+
9+
from alembic.util.exc import AutogenerateDiffsDetected
10+
from sqlalchemy import create_engine
11+
from sqlalchemy.engine import Engine
12+
13+
from pyrit.memory.migration import check_schema_migrations, generate_schema_migration, run_schema_migrations
14+
15+
# ANSI color codes
16+
_RED = "\033[91m"
17+
_RESET = "\033[0m"
18+
19+
20+
def _print_error(message: str) -> None:
    """Write *message* to stderr, wrapped in the red and reset ANSI codes."""
    colored = f"{_RED}{message}{_RESET}"
    print(colored, file=sys.stderr)
23+
24+
25+
def _create_temp_engine() -> tuple[Engine, Path]:
    """
    Create a temp SQLite database upgraded to head and return engine and path.

    On success the caller owns both returned objects and is responsible for
    disposing the engine and deleting the file. If creating the engine or
    running the migrations fails, the engine is disposed and the temp file is
    removed here before the exception propagates, so nothing is left behind.

    Returns:
        tuple[Engine, Path]: The engine bound to the temp database and the path
        of the database file.
    """
    # delete=False keeps the file on disk once the handle is closed; only the
    # reserved path is needed, the database is opened through the engine.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp_path = Path(tmp.name)

    engine: Engine | None = None
    try:
        engine = create_engine(f"sqlite:///{tmp_path}")
        run_schema_migrations(engine=engine)
    except BaseException:
        # Callers only enter their try/finally after this function returns, so
        # a failure here must be cleaned up locally.
        if engine is not None:
            engine.dispose()
        tmp_path.unlink(missing_ok=True)
        raise
    return engine, tmp_path
32+
33+
34+
def _cmd_generate(*, message: str, force: bool = False) -> None:
    """
    Generate a new Alembic revision from model changes.

    Args:
        message: Revision message passed to the migration generator.
        force: Passed through to the generator; the CLI describes it as
            generating a migration even if no changes are detected.

    Raises:
        SystemExit: With code 1 if the generator raises ``RuntimeError``.
    """
    temp_engine, db_file = _create_temp_engine()
    try:
        generate_schema_migration(engine=temp_engine, message=message, force=force)
    except RuntimeError as err:
        _print_error(str(err))
        raise SystemExit(1) from err
    else:
        print("Migration file generated. Review it carefully before committing.")
    finally:
        # Always release the throwaway database, whether or not generation succeeded.
        temp_engine.dispose()
        db_file.unlink(missing_ok=True)
46+
47+
48+
def _cmd_check() -> None:
    """
    Verify all migrations apply cleanly and schema matches models.

    Raises:
        SystemExit: With code 1 if the check raises ``AutogenerateDiffsDetected``.
    """
    temp_engine, db_file = _create_temp_engine()
    try:
        check_schema_migrations(engine=temp_engine)
    except AutogenerateDiffsDetected as err:
        _print_error(f"Migration check failed. Run 'generate' to create a migration. Error: {err}")
        raise SystemExit(1) from err
    finally:
        # Always release the throwaway database, whether or not the check passed.
        temp_engine.dispose()
        db_file.unlink(missing_ok=True)
59+
60+
61+
def _build_parser() -> argparse.ArgumentParser:
    """
    Build the CLI argument parser.

    Returns:
        argparse.ArgumentParser: A parser with a required ``command`` subcommand,
        either ``generate`` (with ``-m/--message`` and ``--force``) or ``check``.
    """
    root = argparse.ArgumentParser(
        description="PyRIT memory migration tool. Generate and validate migrations based on the current memory models."
    )
    commands = root.add_subparsers(dest="command", required=True)

    generate_cmd = commands.add_parser("generate", help="Generate a new migration from model changes.")
    generate_cmd.add_argument("-m", "--message", required=True, help="Migration message.")
    generate_cmd.add_argument("--force", action="store_true", help="Generate migration even if no changes detected.")

    commands.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.")

    return root
75+
76+
77+
def main() -> int:
    """
    Dispatch the selected migration command.

    Returns:
        int: Always 0; failures inside the commands exit via ``SystemExit``.
    """
    parsed = _build_parser().parse_args()
    command = parsed.command

    if command == "check":
        _cmd_check()
    elif command == "generate":
        _cmd_generate(message=parsed.message, force=parsed.force)

    return 0
87+
88+
89+
# Script entry point: use main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Memory Models & Migrations
2+
3+
This guide covers how to work with PyRIT's memory models — where they live, how to add or update them, and how the migration system works.
4+
5+
## Where Things Live
6+
7+
| What | Path |
8+
|---|---|
9+
| ORM models (SQLAlchemy) | `pyrit/memory/memory_models.py` |
10+
| Domain objects they map to | `pyrit/models/` (e.g. `MessagePiece`, `Score`, `Seed`, `AttackResult`, `ScenarioResult`) |
11+
| Alembic migration environment | `pyrit/memory/alembic/env.py` |
12+
| Migration revisions | `pyrit/memory/alembic/versions/` |
13+
| Migration helpers | `pyrit/memory/migration.py` |
14+
| CLI migration tool | `build_scripts/memory_migrations.py` |
15+
| Schema diagram | `doc/code/memory/10_schema_diagram.md` |
16+
17+
## Current Models
18+
19+
All models inherit from the SQLAlchemy `Base` declarative class and live in `memory_models.py`:
20+
21+
- **`PromptMemoryEntry`** — prompt/response data (`PromptMemoryEntries` table)
22+
- **`ScoreEntry`** — evaluation results (`ScoreEntries` table)
23+
- **`EmbeddingDataEntry`** — embeddings for semantic search (`EmbeddingData` table)
24+
- **`SeedEntry`** — dataset prompts/templates (`SeedPromptEntries` table)
25+
- **`AttackResultEntry`** — attack execution results (`AttackResultEntries` table)
26+
- **`ScenarioResultEntry`** — scenario execution metadata (`ScenarioResultEntries` table)
27+
28+
Each entry model has a corresponding domain object and conversion methods (e.g. `PromptMemoryEntry.__init__(entry: MessagePiece)` and `get_message_piece()`).
29+
30+
## Adding or Updating a Model
31+
32+
### 1. Edit the model
33+
34+
Make your changes in `pyrit/memory/memory_models.py`. Follow these conventions:
35+
36+
- Use `mapped_column()` with explicit types.
37+
- Use `CustomUUID` for all UUID columns (handles cross-database compatibility).
38+
- Add foreign keys where relationships exist.
39+
- Include `pyrit_version` on new entry models.
40+
41+
### 2. Generate a migration
42+
43+
```bash
44+
python build_scripts/memory_migrations.py generate -m "short description of change"
45+
```
46+
47+
This creates a new revision file under `pyrit/memory/alembic/versions/`. **Review the generated file carefully** — auto-generated migrations may need manual adjustments (e.g. for data migrations or default values).
48+
49+
### 3. Validate the migration
50+
51+
```bash
52+
python build_scripts/memory_migrations.py check
53+
```
54+
55+
This verifies the schema produced by running all migrations matches the current models. The same check runs in a pre-commit hook (see below) and in CI.
56+
57+
### 4. Update the schema diagram
58+
59+
If you changed the schema in a meaningful way (added a table, added a foreign key, etc.), update the Mermaid diagram in `doc/code/memory/10_schema_diagram.md`.
60+
61+
## How Migrations Run at Startup
62+
63+
Schema migrations are triggered inside each memory class constructor (`SQLiteMemory.__init__` and `AzureSQLMemory.__init__`). When `skip_schema_migration=False` (the default), the inherited `_run_schema_migration()` method on `MemoryInterface` runs:
64+
65+
```
66+
SQLiteMemory.__init__() / AzureSQLMemory.__init__()
67+
→ _run_schema_migration() # pyrit/memory/memory_interface.py
68+
→ run_schema_migrations(engine=...) # pyrit/memory/migration.py
69+
→ alembic upgrade head
70+
→ check_schema_migrations(engine=...) # pyrit/memory/migration.py
71+
→ alembic check
72+
```
73+
74+
Both SQLite and AzureSQL follow the same migration path: first `run_schema_migrations` applies any pending Alembic revisions (`alembic upgrade head`), then `check_schema_migrations` verifies the resulting schema matches the current models (`alembic check`). The behavior depends on database state:
75+
76+
| Database state | What happens |
77+
|---|---|
78+
| **Fresh (no tables)** | All migrations apply from scratch |
79+
| **Already versioned** | Only unapplied migrations run (idempotent) |
80+
| **Legacy (tables exist, no version tracking)** | Validates schema matches models, stamps current version, then upgrades. Raises `RuntimeError` on mismatch to prevent data corruption |
81+
82+
Migrations run inside a transaction (`engine.begin()`), so a failed migration rolls back cleanly. The version tracking table is `pyrit_memory_alembic_version`.
83+
84+
Users can skip migrations by passing `skip_schema_migration=True` to the memory class constructor. When using `initialize_pyrit_async()`, this can be forwarded via `**memory_instance_kwargs`:
85+
86+
```python
87+
await initialize_pyrit_async("SQLite", skip_schema_migration=True)
88+
```
89+
90+
## Important Rules
91+
92+
### Migration revisions are immutable
93+
94+
Once a migration revision is committed, it **must not be modified or deleted**. This is enforced by a pre-commit hook (`enforce_alembic_revision_immutability`). If you need to fix a migration, create a new revision instead.
95+
96+
### Pre-commit hooks
97+
98+
Two hooks run automatically when you touch memory-related files:
99+
100+
1. **`enforce_alembic_revision_immutability`** — blocks modifications/deletions to existing revision files.
101+
2. **`memory-migrations-check`** — runs `memory_migrations.py check` to verify the schema is in sync.
102+
103+
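The immutability hook triggers on changes to Python files under `pyrit/memory/alembic/versions/`. The migrations check triggers on changes to `pyrit/memory/memory_models.py`, `pyrit/memory/migration.py`, and files under `pyrit/memory/alembic/`.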

doc/myst.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ project:
5353
- file: contributing/8_pre_commit.md
5454
- file: contributing/9_exception.md
5555
- file: contributing/10_release_process.md
56+
- file: contributing/11_memory_models.md
5657
- file: gui/0_gui.md
5758
- file: scanner/0_scanner.md
5859
children:

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ classifiers = [
2828
requires-python = ">=3.10, <3.15"
2929
dependencies = [
3030
"aiofiles>=24,<25",
31+
"alembic>=1.16.0",
3132
"appdirs>=1.4.0",
3233
"art>=6.5.0",
3334
"av>=14.0.0",
@@ -201,6 +202,8 @@ include = ["pyrit", "pyrit.*"]
201202
[tool.setuptools.package-data]
202203
pyrit = [
203204
"backend/frontend/**/*",
205+
"memory/alembic/**/*",
206+
"memory/alembic.ini",
204207
"py.typed"
205208
]
206209

pyrit/memory/alembic/env.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from alembic import context
5+
from sqlalchemy.engine import Connection
6+
7+
from pyrit.memory.memory_models import Base
8+
from pyrit.memory.migration import PYRIT_MEMORY_ALEMBIC_VERSION_TABLE
9+
10+
# Alembic configuration for this run; a connection is expected under
# ``config.attributes["connection"]``.
config = context.config
target_metadata = Base.metadata
connection: Connection | None = config.attributes.get("connection")

# Fail before configuring anything if no connection was supplied.
if connection is None:
    raise RuntimeError("No connection found for Alembic migration")

context.configure(
    version_table=PYRIT_MEMORY_ALEMBIC_VERSION_TABLE,
    compare_type=True,
    target_metadata=target_metadata,
    connection=connection,
)

# Run the migrations inside the transaction scope Alembic provides.
with context.begin_transaction():
    context.run_migrations()
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
${message}.
6+
7+
Revision ID: ${up_revision}
8+
Revises: ${down_revision | comma,n}
9+
Create Date: ${create_date}
10+
"""
11+
12+
from collections.abc import Sequence
13+
14+
import sqlalchemy as sa
15+
from alembic import op
16+
${imports if imports else ""}
17+
18+
# revision identifiers, used by Alembic.
19+
revision: str = "${up_revision}"
20+
down_revision: str | None = ${repr(down_revision).replace("'", '"')}
21+
branch_labels: str | Sequence[str] | None = ${repr(branch_labels).replace("'", '"')}
22+
depends_on: str | Sequence[str] | None = ${repr(depends_on).replace("'", '"')}
23+
24+
25+
def upgrade() -> None:
26+
"""Apply this schema upgrade."""
27+
${upgrades if upgrades else "pass"}
28+
29+
30+
def downgrade() -> None:
31+
"""Revert this schema upgrade."""
32+
${downgrades if downgrades else "pass"}

0 commit comments

Comments
 (0)