Skip to content

Commit bb02aff

Browse files
Extend conformance harness for proposals 0027/0028/0029
Fixture range expands to 048-056 (was 048-054), picking up 055 (schema_version declared class) and 056 (fan-out count drift) on the v0.22.1 spec submodule. Per-instance saved-record matcher gains result_is_error (boolean discriminator) and result_present (existence check without constraining shape) per proposal 0027; the pre-0027 result_kind: error shape heuristic is no longer recognized by the harness since the upstream fixtures retired that key. For proposal 0028, the harness honors a state.schema_version directive on the fixture by mutating the constructed state class's schema_version attribute, plus a runtime_state_subclass: {schema_version: <v>} directive that builds a Python subclass with the override and instantiates initial_state from it. The every_save_assertions block iterates every captured save and asserts each declared key matches — catches save sites that read schema_version from type(state) at save time. For proposal 0029, the harness recognizes resume_with_modified_items on the resume block by installing load-time state overrides on the capturing checkpointer; the engine reads back the mutated outer state on resume and the fan-out node's count-drift check raises. The resume.expected_error matcher catches both CheckpointError and RuntimeGraphError and asserts the category — CheckpointRecordInvalid isn't a RuntimeGraphError subclass so a stricter match would miss the proposal-0029 path.
1 parent 113f946 commit bb02aff

1 file changed

Lines changed: 164 additions & 21 deletions

File tree

tests/conformance/test_checkpoint.py

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@
5959
)
6060

6161
# Conformance fixture range: 024-031 minus 028 are the proposal-0008
62-
# set; 048-054 are the proposal-0009 per-instance-resume set. 028
63-
# (fan-out atomic-restart) was REMOVED in spec v0.18.0 when proposal
64-
# 0009 superseded its contract, so it is explicitly excluded from the
65-
# set rather than relying on the test runner's file-glob to filter
66-
# the missing fixture out.
67-
_CHECKPOINT_FIXTURE_NUMBERS: frozenset[int] = frozenset((set(range(24, 32)) - {28}) | set(range(48, 55)))
62+
# set; 048-054 are the proposal-0009 per-instance-resume set; 055
63+
# (schema_version declared class — proposal 0028) and 056 (fan-out
64+
# count drift — proposal 0029) are the follow-on bundle. 028 (fan-out
65+
# atomic-restart) was REMOVED in spec v0.18.0 when proposal 0009
66+
# superseded its contract, so it is explicitly excluded from the set
67+
# rather than relying on the test runner's file-glob to filter the
68+
# missing fixture out.
69+
_CHECKPOINT_FIXTURE_NUMBERS: frozenset[int] = frozenset((set(range(24, 32)) - {28}) | set(range(48, 57)))
6870

6971
# Fixtures that need resume-aware test seams the conformance adapter
7072
# doesn't yet translate. Skipped here with a clear reason — the engine
@@ -142,6 +144,17 @@ def __init__(
142144
self.saves: list[CheckpointRecord] = []
143145
self._abort_after_instance = abort_after_instance
144146
self._aborted = False
147+
# Per proposal 0029 (fixture 056): mutating the saved record's
148+
# outer state on ``load`` simulates "user shrank/grew the input
149+
# set between runs." The engine restores from this mutated
150+
# state, the fan-out node re-resolves count from the mutated
151+
# ``items``, and the count-drift check raises
152+
# ``checkpoint_record_invalid`` because the saved
153+
# ``fan_out_progress`` entry's ``instance_count`` doesn't match.
154+
# Keys are field names on the outer state; values replace
155+
# those fields when the record is returned to the engine on
156+
# resume.
157+
self.load_state_overrides: dict[str, Any] = {}
145158

146159
async def save(self, invocation_id: str, record: CheckpointRecord) -> None:
147160
self._raise_if_post_abort()
@@ -187,7 +200,26 @@ def _maybe_abort(self, record: CheckpointRecord) -> None:
187200
raise _AbortAfterInstance(f"simulated crash after instance {target_idx} completed save")
188201

189202
async def load(self, invocation_id: str) -> CheckpointRecord | None:
190-
return await self._inner.load(invocation_id)
203+
record = await self._inner.load(invocation_id)
204+
if record is None or not self.load_state_overrides:
205+
return record
206+
# Apply overrides to the outer state. For outer-level saves the
207+
# outer state is ``record.state``; for inner saves (fan-out
208+
# instance, subgraph) it's ``record.parent_states[0]``. Mutate
209+
# whichever shape is present so the test driver doesn't need
210+
# to care which save site landed last.
211+
from dataclasses import replace as dataclass_replace # noqa: PLC0415
212+
213+
if record.parent_states:
214+
outer = record.parent_states[0]
215+
outer_updates = {**outer.model_dump(), **self.load_state_overrides}
216+
new_outer = type(outer)(**outer_updates)
217+
new_parents = (new_outer,) + record.parent_states[1:]
218+
return dataclass_replace(record, parent_states=new_parents)
219+
outer = record.state
220+
outer_updates = {**outer.model_dump(), **self.load_state_overrides}
221+
new_outer = type(outer)(**outer_updates)
222+
return dataclass_replace(record, state=new_outer)
191223

192224
async def list(self, filter: Any = None) -> Any:
193225
return await self._inner.list(filter)
@@ -316,9 +348,46 @@ async def _run_one_case(spec: Mapping[str, Any], *, top_level: Mapping[str, Any]
316348
flaky_per_index_attempt_recorders=flaky_per_index_recorders,
317349
)
318350
builder = built.builder
351+
352+
# Per proposal 0028 (fixture 055): the fixture's ``state.schema_version``
353+
# directive declares the graph state class's schema_version, and the
354+
# optional ``runtime_state_subclass.schema_version`` directive
355+
# creates a subclass shadowing it. The harness applies both directly
356+
# to the constructed state class (build_state_cls in adapter.py
357+
# ignores schema_version today — supporting it via class-level
358+
# attribute writes here keeps the adapter signature stable).
359+
state_block = cast("Mapping[str, Any]", spec.get("state") or {})
360+
declared_schema_version = state_block.get("schema_version")
361+
if declared_schema_version is not None:
362+
built.state_cls.schema_version = str(declared_schema_version)
363+
319364
builder.with_checkpointer(cast("Checkpointer", capturing))
320365
compiled = builder.compile()
321-
initial_state = built.initial_state(spec.get("initial_state", {}))
366+
367+
# Per proposal 0028: ``runtime_state_subclass`` constructs a Python
368+
# subclass with the overridden ``schema_version`` and passes an
369+
# instance of THAT subclass to ``invoke()``. The test verifies the
370+
# engine ignores the subclass's value and writes saves using the
371+
# declared class's value — proving §10.2's "declared class is
372+
# canonical" rule.
373+
runtime_subclass_directive = cast(
374+
"Mapping[str, Any] | None",
375+
spec.get("runtime_state_subclass"),
376+
)
377+
if runtime_subclass_directive is not None:
378+
override_version = str(runtime_subclass_directive["schema_version"])
379+
# Subclass with ClassVar override at the class level. The
380+
# subclass IS-A built.state_cls (Pydantic structural-conformance
381+
# holds), so ``compiled.invoke(subclass_instance, ...)`` accepts
382+
# it without complaint.
383+
runtime_subclass = type(
384+
f"{built.state_cls.__name__}Runtime",
385+
(built.state_cls,),
386+
{"schema_version": override_version},
387+
)
388+
initial_state = cast("State", runtime_subclass(**spec.get("initial_state", {})))
389+
else:
390+
initial_state = built.initial_state(spec.get("initial_state", {}))
322391

323392
# Run #1 — first invocation. May succeed or fail per fixture.
324393
first_run_expected_error = spec.get("first_run_expected_error")
@@ -420,14 +489,39 @@ async def _run_one_case(spec: Mapping[str, Any], *, top_level: Mapping[str, Any]
420489
if "invariants" in expected:
421490
_assert_invariants(cast("Mapping[str, Any]", expected["invariants"]), capturing.saves)
422491

492+
# Per proposal 0028 (fixture 055): ``every_save_assertions`` is a
493+
# cross-save invariant block — every captured save during the
494+
# invocation MUST match every key in this block. Catches
495+
# implementations that read ``schema_version`` from
496+
# ``type(state)`` (the runtime subclass) at any intermediate save
497+
# site instead of from the declared graph state class. Distinct
498+
# from ``invariants`` above which asserts properties of the SET of
499+
# saves (e.g., "at least one save fired"); this asserts the same
500+
# property holds on EVERY save.
501+
every_save_block = cast(
502+
"Mapping[str, Any] | None",
503+
spec.get("every_save_assertions"),
504+
)
505+
if every_save_block is not None:
506+
assert capturing.saves, (
507+
"every_save_assertions declared but no saves were captured during the invocation"
508+
)
509+
for save_idx, saved_record in enumerate(capturing.saves):
510+
for key, expected_value in every_save_block.items():
511+
actual_value = getattr(saved_record, key, None)
512+
assert actual_value == expected_value, (
513+
f"every_save_assertions: save[{save_idx}].{key} mismatch — "
514+
f"actual={actual_value!r}, expected={expected_value!r}"
515+
)
516+
423517
# ----- checkpoint_not_found expected (fixture 030) -----
424518
if expected.get("expected_error") == "checkpoint_not_found":
425519
ghost = cast("str", expected.get("resume_invocation_id", "ghost"))
426520
with pytest.raises(CheckpointNotFound):
427521
await compiled.invoke(initial_state, resume_invocation=ghost)
428522
return
429523

430-
# ----- Resume path (fixtures 025, 029, 031, 048-054) -----
524+
# ----- Resume path (fixtures 025, 029, 031, 048-054, 056) -----
431525
resume_block = spec.get("resume")
432526
if resume_block is None or not resume_block.get("from_first_run"):
433527
return
@@ -448,6 +542,49 @@ async def _run_one_case(spec: Mapping[str, Any], *, top_level: Mapping[str, Any]
448542
capturing._abort_after_instance = None # noqa: SLF001
449543
# Clear the trace so post-resume execution capture is isolated.
450544
trace.clear()
545+
546+
# Per proposal 0029 (fixture 056): ``resume_with_modified_items``
547+
# simulates "user changed the input set between runs." The engine
548+
# restores state from the saved record on resume (the
549+
# ``initial_state`` parameter to ``invoke`` is ignored on the
550+
# resume path); to actually mutate the resumed run's state we
551+
# install overrides on the capturing checkpointer's ``load``
552+
# path, which patches the outer state when the engine reads back
553+
# the saved record. The fan-out node then re-resolves its count
554+
# from the mutated state and the count-drift check raises.
555+
modified_items_directive = cast(
556+
"Mapping[str, Any] | None",
557+
resume_block.get("resume_with_modified_items"),
558+
)
559+
if modified_items_directive is not None:
560+
capturing.load_state_overrides = dict(modified_items_directive)
561+
562+
# Per proposal 0029: a resume that hits count drift MUST raise
563+
# ``checkpoint_record_invalid``. ``resume.expected_error`` carries
564+
# the assertion (sibling to ``resume.expected``); when present, the
565+
# invoke MUST raise the named category before final_state can be
566+
# checked.
567+
resume_expected_error = cast(
568+
"Mapping[str, Any] | None",
569+
resume_block.get("expected_error"),
570+
)
571+
if resume_expected_error is not None:
572+
# CheckpointRecordInvalid (the proposal-0029 count-drift category)
573+
# is a CheckpointError, NOT a RuntimeGraphError — they're sibling
574+
# categorized error hierarchies. Catch the broader Exception and
575+
# assert ``category`` on the value to match both paths.
576+
with pytest.raises((CheckpointError, RuntimeGraphError)) as excinfo:
577+
await compiled.invoke(
578+
initial_state,
579+
resume_invocation=invocation_id_first_run,
580+
)
581+
expected_cat = resume_expected_error["category"]
582+
actual_cat = cast("str", getattr(excinfo.value, "category", ""))
583+
assert actual_cat == expected_cat, (
584+
f"resume expected_error category mismatch: actual={actual_cat!r}, expected={expected_cat!r}"
585+
)
586+
return
587+
451588
try:
452589
final_resume = await compiled.invoke(
453590
initial_state,
@@ -692,19 +829,25 @@ def _assert_fan_out_instance(
692829
f"fan_out_progress[{node_name!r}].instances[{idx}].result: "
693830
f"actual={actual.result!r}, expected={expected['result']!r}"
694831
)
695-
if expected.get("result_kind") == "error":
696-
# Spec §10.11.2: collect-mode error contributions are recorded
697-
# as the per-instance result entry. The engine ships
698-
# ``dict[str, str]`` with ``fan_out_index`` and ``category``.
699-
raw_result: Any = actual.result
700-
assert isinstance(raw_result, dict), (
701-
f"fan_out_progress[{node_name!r}].instances[{idx}].result: "
702-
f"expected dict (error_record), got {type(raw_result).__name__}"
832+
if "result_is_error" in expected:
833+
# Spec §10.11 (proposal 0027): explicit boolean discriminator
834+
# on the per-instance entry. Replaced the pre-0027
835+
# ``result_kind: error`` shape heuristic.
836+
assert actual.result_is_error == expected["result_is_error"], (
837+
f"fan_out_progress[{node_name!r}].instances[{idx}].result_is_error: "
838+
f"actual={actual.result_is_error!r}, expected={expected['result_is_error']!r}"
703839
)
704-
result_dict = cast("dict[str, Any]", raw_result)
705-
assert "category" in result_dict, (
706-
f"fan_out_progress[{node_name!r}].instances[{idx}].result: "
707-
f"expected error_record with 'category' key, got {result_dict!r}"
840+
if "result_present" in expected:
841+
# Spec §10.11 (proposal 0027): assert the ``result`` field
842+
# exists on the saved record without constraining its shape
843+
# (the value remains impl-defined per §9.5). Pair with
844+
# ``result_is_error: true`` to assert "an error contribution
845+
# was captured" without locking the test to one impl's error
846+
# record format.
847+
result_present_actual = actual.result is not None
848+
assert result_present_actual == expected["result_present"], (
849+
f"fan_out_progress[{node_name!r}].instances[{idx}].result_present: "
850+
f"actual={result_present_actual!r}, expected={expected['result_present']!r}"
708851
)
709852
if "completed_inner_positions" in expected:
710853
positions_expected = cast("list[Mapping[str, Any]]", expected["completed_inner_positions"])

0 commit comments

Comments
 (0)