Skip to content

Commit fb0c992

Browse files
feat(checkpoint): state migration registry, types, errors, builder surface
Implements pipeline-utilities spec §10.12 (proposal 0014). - New errors: CheckpointStateMigrationMissing, CheckpointStateMigrationFailed. Both non-transient per §10.10. The missing-chain error carries from_version / to_version / registered_migrations_count / registry_description for actionable operator diagnostics. - New types: StateMigration (frozen dataclass — from_version, to_version, migrate callable) and MigrationRegistry (BFS chain resolution + ambiguity detection per §10.12.2). - Multi-shortest-path detection: when BFS finds a shortest path AND a second distinct path of equal length exists, the registry raises ValueError per the spec's ambiguous-chain rule. Resume surfaces this as CheckpointStateMigrationMissing with the ambiguity description in the payload. - State.schema_version: ClassVar[str] = '' (per spec §10.2's per-language carve-out). Empty-string sentinel; the framework reads type(state).schema_version at save time. - Checkpointer Protocol: supports_state_migration: ClassVar[bool] marker per §10.12.1. InMemoryCheckpointer: False (typed in- memory references can't expose a class-independent intermediate). SQLiteCheckpointer: True in JSON mode, False in pickle mode (pickle holds class identity and round-trips to typed instances; can't bridge versions). - GraphBuilder.with_state_migration / with_state_migrations thread a populated MigrationRegistry into CompiledGraph at compile time. - Resume-path routing (compiled.py): version mismatch → unsupported-backend check → registry lookup → chain application (with per-migration failure wrap) → final deserialization. The post-migration deserialization failure still surfaces as CheckpointRecordInvalid per §10.12.4; pre-migration version mismatch routes through the new two categories. Order matters; documented inline so a future reader doesn't swap it back. - Parent-state migration: same chain applied to each entry of parent_states in lockstep with the outer state per §10.12.2. Code comment records the spec-mandated equivalence so future contributors don't add per-parent metadata without a follow-on proposal. - Drop the CHECKPOINT_SCHEMA_VERSION = '1' constant: per Q1 spec answer, the old backend-internal record-shape role had no spec slot anyway. SQLiteCheckpointer no longer rejects records with non-default versions on load — that routing is now the engine's concern at resume time. Existing records carrying schema_version='1' get reinterpreted as user-facing v1 identifiers (single-user dev, no compat shim needed per Chris's note).
1 parent ddc99f2 commit fb0c992

10 files changed

Lines changed: 512 additions & 49 deletions

File tree

src/openarmature/checkpoint/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
CheckpointNotFound,
2727
CheckpointRecordInvalid,
2828
CheckpointSaveFailed,
29+
CheckpointStateMigrationFailed,
30+
CheckpointStateMigrationMissing,
2931
)
32+
from .migration import MigrationRegistry, StateMigration
3033
from .protocol import (
31-
CHECKPOINT_SCHEMA_VERSION,
3234
Checkpointer,
3335
CheckpointFilter,
3436
CheckpointRecord,
@@ -37,17 +39,20 @@
3739
)
3840

3941
__all__ = [
40-
"CHECKPOINT_SCHEMA_VERSION",
4142
"CheckpointError",
4243
"CheckpointFilter",
4344
"CheckpointNotFound",
4445
"CheckpointRecord",
4546
"CheckpointRecordInvalid",
4647
"CheckpointSaveFailed",
48+
"CheckpointStateMigrationFailed",
49+
"CheckpointStateMigrationMissing",
4750
"CheckpointSummary",
4851
"Checkpointer",
4952
"InMemoryCheckpointer",
53+
"MigrationRegistry",
5054
"NodePosition",
5155
"SQLiteCheckpointer",
5256
"SerializationMode",
57+
"StateMigration",
5358
]

src/openarmature/checkpoint/backends/memory.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import asyncio
1414
from collections.abc import Iterable
15+
from typing import ClassVar
1516

1617
from ..protocol import CheckpointFilter, CheckpointRecord, CheckpointSummary
1718

@@ -28,8 +29,21 @@ class InMemoryCheckpointer:
2829
Pydantic state instance the engine produces is what comes back
2930
from :meth:`load` — no serialization round-trip. (This is the
3031
feature: tests can assert on the saved state's identity.)
32+
33+
**State-migration eligibility:** none. Per spec §10.12.1, a
34+
backend supports migration only when it can expose a structural
35+
intermediate form of the loaded state independent of the current
36+
state class. This backend holds live typed instances by
37+
reference, so a version mismatch on resume raises
38+
``CheckpointRecordInvalid`` rather than consulting the
39+
migration registry.
3140
"""
3241

42+
# Per spec §10.12.1: in-memory storage holds live typed-state
43+
# references, so there's no class-independent intermediate form
44+
# the migration registry could consume.
45+
supports_state_migration: ClassVar[bool] = False
46+
3347
def __init__(self) -> None:
3448
self._records: dict[str, CheckpointRecord] = {}
3549
self._lock = asyncio.Lock()

src/openarmature/checkpoint/backends/sqlite.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
from ..errors import CheckpointRecordInvalid
4444
from ..protocol import (
45-
CHECKPOINT_SCHEMA_VERSION,
4645
CheckpointFilter,
4746
CheckpointRecord,
4847
CheckpointSummary,
@@ -109,6 +108,13 @@ def __init__(
109108
self._serialization: SerializationMode = serialization
110109
self._lock = asyncio.Lock()
111110
self._initialized = False
111+
# Per spec §10.12.1, a backend supports state migration only
112+
# when it can expose a structural intermediate form of the
113+
# loaded state that is independent of the current state
114+
# class. JSON serialization satisfies this (loads to dicts);
115+
# pickle holds class identity and round-trips to typed
116+
# instances, so it cannot bridge a schema-version mismatch.
117+
self.supports_state_migration: bool = serialization == "json"
112118

113119
def _connect(self) -> sqlite3.Connection:
114120
conn = sqlite3.connect(self._path)
@@ -230,12 +236,12 @@ def _do() -> tuple[Any, ...] | None:
230236
schema_version,
231237
recorded_serialization,
232238
) = row
233-
if schema_version != CHECKPOINT_SCHEMA_VERSION:
234-
raise CheckpointRecordInvalid(
235-
invocation_id,
236-
f"persisted schema_version={schema_version!r} does not match "
237-
f"current {CHECKPOINT_SCHEMA_VERSION!r}",
238-
)
239+
# Note: per spec §10.12 (proposal 0014), version mismatches
240+
# are no longer rejected at the backend boundary. The engine
241+
# routes mismatches through the migration registry on resume
242+
# (CheckpointStateMigrationMissing if no chain, else applies
243+
# the chain). The backend just round-trips the version
244+
# identifier as opaque data.
239245
state = self._decode(state_blob, recorded_serialization, invocation_id)
240246
position_dicts = self._decode(positions_blob, recorded_serialization, invocation_id)
241247
parent_states = self._decode(parent_states_blob, recorded_serialization, invocation_id)

src/openarmature/checkpoint/errors.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from __future__ import annotations
1919

20+
from typing import Any
21+
2022

2123
class CheckpointError(Exception):
2224
"""Base for all checkpoint errors. Each subclass carries a
@@ -56,10 +58,17 @@ def __init__(self, invocation_id: str, cause: BaseException) -> None:
5658

5759
class CheckpointRecordInvalid(CheckpointError):
5860
"""Raised when ``Checkpointer.load(X)`` returns a record whose
59-
schema is incompatible with the current graph (state shape
60-
mismatch, missing required fields, or
61-
``schema_version`` mismatch). Non-transient — the persisted
62-
record was written by an incompatible version of the engine."""
61+
schema is incompatible with the current graph: state shape
62+
mismatch, missing required fields, OR a post-migration state
63+
that fails to deserialize against the current state class (per
64+
spec §10.12.4). Non-transient.
65+
66+
Note: raw ``schema_version`` mismatches no longer route here.
67+
They now flow through ``CheckpointStateMigrationMissing`` (no
68+
chain registered) or ``CheckpointStateMigrationFailed`` (chain
69+
application raised) per spec §10.10's three-way category
70+
distinction.
71+
"""
6372

6473
category = "checkpoint_record_invalid"
6574

@@ -68,9 +77,69 @@ def __init__(self, invocation_id: str, message: str) -> None:
6877
self.invocation_id = invocation_id
6978

7079

80+
class CheckpointStateMigrationMissing(CheckpointError):
81+
"""Raised on resume when the saved record's ``schema_version``
82+
does not match the current state class's ``schema_version`` AND
83+
no chain of registered migrations bridges the two. Non-transient
84+
per spec §10.10 — the user MUST register a migration (or pin
85+
their state to the saved version) for the resume to succeed.
86+
87+
Carries the saved-from / current-to versions and a description
88+
of the registered migration set so the user can see what
89+
migrations are available.
90+
"""
91+
92+
category = "checkpoint_state_migration_missing"
93+
94+
from_version: str
95+
to_version: str
96+
registered_migrations_count: int
97+
registry_description: str
98+
99+
def __init__(
100+
self,
101+
*args: Any,
102+
from_version: str,
103+
to_version: str,
104+
registered_migrations_count: int,
105+
registry_description: str,
106+
) -> None:
107+
super().__init__(*args)
108+
self.from_version = from_version
109+
self.to_version = to_version
110+
self.registered_migrations_count = registered_migrations_count
111+
self.registry_description = registry_description
112+
113+
114+
class CheckpointStateMigrationFailed(CheckpointError):
115+
"""Raised on resume when a registered migration function raises
116+
during chain application (per spec §10.12.2). The migration's
117+
exception is preserved as ``__cause__``. Non-transient by
118+
default: a buggy migration is deterministic, so retrying
119+
without changing the migration code will not succeed.
120+
"""
121+
122+
category = "checkpoint_state_migration_failed"
123+
124+
from_version: str
125+
to_version: str
126+
127+
def __init__(
128+
self,
129+
*args: Any,
130+
from_version: str,
131+
to_version: str,
132+
) -> None:
133+
super().__init__(*args)
134+
self.from_version = from_version
135+
self.to_version = to_version
136+
137+
71138
__all__ = [
72139
"CheckpointError",
73140
"CheckpointNotFound",
74141
"CheckpointRecordInvalid",
75142
"CheckpointSaveFailed",
143+
"CheckpointStateMigrationFailed",
144+
"CheckpointStateMigrationMissing",
76145
]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""State migration types and registry.
2+
3+
Realizes pipeline-utilities §10.12 (proposal 0014). A
4+
``StateMigration`` describes one edge in the migration graph;
5+
``MigrationRegistry`` holds the ordered set and resolves chains
6+
via BFS. Ambiguity (duplicate ``(from, to)`` pairs OR multiple
7+
distinct shortest paths between the same source/sink) is a
8+
configuration-style error per §10.12.1 / §10.12.2.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from collections import deque
14+
from collections.abc import Callable, Iterator
15+
from dataclasses import dataclass
16+
from typing import Any
17+
18+
19+
@dataclass(frozen=True)
20+
class StateMigration:
21+
"""One edge in the migration graph.
22+
23+
``migrate`` receives the most-deserialized form the backend can
24+
expose that is still independent of the current state class
25+
(a plain ``dict`` for JSON-backed backends). It MUST return a
26+
value of the same kind, suitable for the next migration in the
27+
chain (or for final deserialization into the current state class).
28+
29+
Migrations MUST be pure: deterministic, no I/O, no implicit
30+
state. The framework does not police purity per spec §10.12.2
31+
("the contract is documented, not policed"); violating it
32+
risks non-deterministic resume.
33+
"""
34+
35+
from_version: str
36+
to_version: str
37+
migrate: Callable[[Any], Any]
38+
39+
40+
class MigrationRegistry:
41+
"""Ordered set of registered migrations + BFS chain resolution.
42+
43+
Registration-time invariants:
44+
45+
- Two migrations with the same ``from_version`` AND
46+
``to_version`` raise ``ValueError`` (chain ambiguity per
47+
§10.12.1).
48+
- Two migrations with the same ``from_version`` and different
49+
``to_version`` are permitted (branched migration graph;
50+
chain resolution picks a path).
51+
52+
Resolution-time semantics (per §10.12.2):
53+
54+
- BFS from ``record.schema_version`` to
55+
``current.schema_version``. BFS naturally finds the shortest
56+
path.
57+
- Empty registry on mismatch → no path → caller raises
58+
``CheckpointStateMigrationMissing``.
59+
- Non-empty registry with no connecting path → same.
60+
- Found a unique shortest path → return ordered list.
61+
- Found multiple distinct shortest paths (same edge count,
62+
different edge sequences) → raise ``ValueError`` per
63+
§10.12.2's ambiguous-chain rule. Spec accepts load-time
64+
detection.
65+
"""
66+
67+
def __init__(self) -> None:
68+
self._migrations: dict[tuple[str, str], StateMigration] = {}
69+
self._edges: dict[str, list[StateMigration]] = {}
70+
71+
def register(self, migration: StateMigration) -> None:
72+
key = (migration.from_version, migration.to_version)
73+
if key in self._migrations:
74+
raise ValueError(
75+
f"duplicate state migration {migration.from_version!r}→"
76+
f"{migration.to_version!r} registered; chain would be ambiguous"
77+
)
78+
self._migrations[key] = migration
79+
self._edges.setdefault(migration.from_version, []).append(migration)
80+
81+
def __iter__(self) -> Iterator[StateMigration]:
82+
return iter(self._migrations.values())
83+
84+
def __len__(self) -> int:
85+
return len(self._migrations)
86+
87+
def resolve_chain(
88+
self,
89+
from_version: str,
90+
to_version: str,
91+
) -> list[StateMigration] | None:
92+
"""Return an ordered chain of migrations bridging the two
93+
versions, or ``None`` if no chain exists.
94+
95+
Raises ``ValueError`` if multiple distinct shortest paths
96+
exist (ambiguous chain per §10.12.2).
97+
"""
98+
if from_version == to_version:
99+
return []
100+
101+
# BFS that records every shortest-length path. If multiple
102+
# paths share the minimum length, the chain is ambiguous.
103+
# Standard BFS finds the shortest distance; the path-recording
104+
# variant lets us detect ambiguity without a second pass.
105+
# ``frontier`` items are (version, path_so_far).
106+
frontier: deque[tuple[str, list[StateMigration]]] = deque()
107+
frontier.append((from_version, []))
108+
shortest_paths: list[list[StateMigration]] = []
109+
shortest_length: int | None = None
110+
# ``distances`` tracks the BFS layer at which each node was
111+
# first seen. Frontier entries past the shortest_length layer
112+
# are pruned.
113+
distances: dict[str, int] = {from_version: 0}
114+
115+
while frontier:
116+
version, path = frontier.popleft()
117+
depth = len(path)
118+
# Stop expanding once we've moved past the shortest target.
119+
if shortest_length is not None and depth >= shortest_length:
120+
continue
121+
for edge in self._edges.get(version, []):
122+
next_version = edge.to_version
123+
next_path = path + [edge]
124+
if next_version == to_version:
125+
if shortest_length is None:
126+
shortest_length = len(next_path)
127+
if len(next_path) == shortest_length:
128+
shortest_paths.append(next_path)
129+
continue
130+
# Cycle-avoidance: a node revisited at the same or
131+
# deeper BFS layer can't contribute to a strict-
132+
# shortest path. Allow re-entry only when the new
133+
# arrival is at the same layer as the first arrival
134+
# (distinct shortest paths through the same node).
135+
prior_depth = distances.get(next_version)
136+
if prior_depth is not None and prior_depth < depth + 1:
137+
continue
138+
distances[next_version] = depth + 1
139+
frontier.append((next_version, next_path))
140+
141+
if not shortest_paths:
142+
return None
143+
if len(shortest_paths) > 1:
144+
descriptions = [" → ".join([from_version, *(e.to_version for e in p)]) for p in shortest_paths]
145+
raise ValueError(
146+
f"ambiguous migration chain from {from_version!r} to "
147+
f"{to_version!r}: multiple distinct shortest paths exist "
148+
f"({descriptions}); register fewer migrations or pick a "
149+
f"single canonical route"
150+
)
151+
return shortest_paths[0]
152+
153+
def describe(self) -> str:
154+
"""Human-readable description of the registered set, used
155+
in the ``CheckpointStateMigrationMissing`` error payload.
156+
Empty registry returns ``"<no migrations registered>"``.
157+
"""
158+
if not self._migrations:
159+
return "<no migrations registered>"
160+
return "\n".join(f"{m.from_version}{m.to_version}" for m in self._migrations.values())
161+
162+
163+
__all__ = ["MigrationRegistry", "StateMigration"]

0 commit comments

Comments
 (0)