Skip to content

Commit 021606e

Browse files
committed
Add Stage 1 checkpoint reuse boundary
1 parent 0924580 commit 021606e

13 files changed

Lines changed: 848 additions & 43 deletions

File tree

changelog.d/1074.added

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added Stage 1 checkpoint adapter and rerun reuse planning boundaries.

modal_app/data_build.py

Lines changed: 143 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import functools
24
import os
35
import shutil
@@ -22,6 +24,7 @@
2224
from modal_app.images import cpu_image as image # noqa: E402
2325
from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402
2426
from policyengine_us_data.build_datasets import ( # noqa: E402
27+
CheckpointStore,
2528
CommandRunner,
2629
DatasetCommand,
2730
DatasetCommandError,
@@ -30,6 +33,9 @@
3033
DatasetBuildOutputContractBuilder,
3134
PipelineArtifactStager,
3235
Stage1Coordinator,
36+
Stage1IdentityMaterial,
37+
Stage1RerunPlanner,
38+
Stage1ReuseDecision,
3339
stage_1_artifact_specs,
3440
stage_1_script_outputs,
3541
stage_1_substep_id_for_script,
@@ -185,30 +191,20 @@ def get_current_commit() -> str:
185191

186192
def get_checkpoint_path(branch: str, output_file: str) -> Path:
187193
"""Get the checkpoint path for an output file, scoped by branch and commit."""
188-
commit = get_current_commit()
189-
return Path(VOLUME_MOUNT) / branch / commit / Path(output_file).name
194+
return _checkpoint_store(branch).checkpoint_path(output_file)
190195

191196

192197
def is_checkpointed(branch: str, output_file: str) -> bool:
193198
"""Check if output file exists in checkpoint volume and is valid."""
194-
checkpoint_path = get_checkpoint_path(branch, output_file)
195-
if checkpoint_path.exists():
196-
# Verify file is not empty/corrupted
197-
if checkpoint_path.stat().st_size > 0:
198-
return True
199-
return False
199+
return _checkpoint_store(branch).decision_for(output_file).action == "reuse"
200200

201201

202202
def restore_from_checkpoint(branch: str, output_file: str) -> bool:
203203
"""Restore output file from checkpoint volume if it exists."""
204-
checkpoint_path = get_checkpoint_path(branch, output_file)
205-
if checkpoint_path.exists() and checkpoint_path.stat().st_size > 0:
206-
local_path = Path(output_file)
207-
local_path.parent.mkdir(parents=True, exist_ok=True)
208-
shutil.copy2(checkpoint_path, local_path)
204+
restored = _checkpoint_store(branch).restore_output(output_file)
205+
if restored:
209206
print(f"Restored from checkpoint: {output_file}")
210-
return True
211-
return False
207+
return restored
212208

213209

214210
def save_checkpoint(
@@ -217,25 +213,35 @@ def save_checkpoint(
217213
volume: modal.Volume,
218214
) -> None:
219215
"""Save output file to checkpoint volume."""
220-
local_path = Path(output_file)
221-
if local_path.exists() and local_path.stat().st_size > 0:
222-
checkpoint_path = get_checkpoint_path(branch, output_file)
223-
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
224-
shutil.copy2(local_path, checkpoint_path)
225-
with _volume_lock:
226-
volume.commit()
216+
saved = _checkpoint_store(branch, volume).save_output(output_file)
217+
if saved:
227218
print(f"Checkpointed: {output_file}")
228219

229220

230221
def cleanup_checkpoints(branch: str, volume: modal.Volume) -> None:
231222
"""Delete checkpoints for this branch after successful completion."""
232-
branch_dir = Path(VOLUME_MOUNT) / branch
233-
if branch_dir.exists():
234-
shutil.rmtree(branch_dir)
235-
volume.commit()
223+
cleaned = _checkpoint_store(branch, volume).cleanup_branch()
224+
if cleaned:
236225
print(f"Cleaned up checkpoints for branch: {branch}")
237226

238227

228+
def _checkpoint_store(
229+
branch: str,
230+
volume: modal.Volume | None = None,
231+
) -> CheckpointStore:
232+
def commit_volume() -> None:
233+
if volume is not None:
234+
with _volume_lock:
235+
volume.commit()
236+
237+
return CheckpointStore(
238+
root=Path(VOLUME_MOUNT),
239+
branch=branch,
240+
commit_sha=get_current_commit(),
241+
commit=commit_volume if volume is not None else None,
242+
)
243+
244+
239245
def run_script_logged(
240246
cmd: list,
241247
log_file: IO,
@@ -368,6 +374,8 @@ def run_script_with_checkpoint(
368374
log_file: IO = None,
369375
checkpoint_stats: CheckpointStats | None = None,
370376
command_results: list[DatasetCommandResult] | None = None,
377+
checkpoint_store: CheckpointStore | None = None,
378+
reuse_decision: Stage1ReuseDecision | None = None,
371379
) -> str:
372380
"""Run script if output not checkpointed, then checkpoint result.
373381
@@ -387,14 +395,29 @@ def run_script_with_checkpoint(
387395
if isinstance(output_files, str):
388396
output_files = [output_files]
389397
expected_count = len(output_files)
398+
checkpoint_store = checkpoint_store or _checkpoint_store(branch, volume)
399+
reuse_decision = reuse_decision or _compat_reuse_decision(
400+
script_path=script_path,
401+
output_files=output_files,
402+
branch=branch,
403+
)
404+
if reuse_decision.action == "blocked":
405+
raise RuntimeError(
406+
"Stage 1 checkpoint reuse is blocked for "
407+
f"{script_path}: {reuse_decision.reason}"
408+
)
390409

391410
# Check if ALL outputs are checkpointed
392-
all_checkpointed = all(is_checkpointed(branch, f) for f in output_files)
411+
checkpoint_decisions = checkpoint_store.decisions_for(output_files)
412+
all_checkpointed = reuse_decision.action == "reuse" and all(
413+
decision.action == "reuse" for decision in checkpoint_decisions
414+
)
393415

394416
if all_checkpointed:
395417
# Restore all files from checkpoint
418+
checkpoint_store.restore_all_outputs(output_files)
396419
for output_file in output_files:
397-
restore_from_checkpoint(branch, output_file)
420+
print(f"Restored from checkpoint: {output_file}")
398421
print(f"Skipping {script_path} (restored from checkpoint)")
399422
if checkpoint_stats is not None:
400423
checkpoint_stats.record(
@@ -404,7 +427,7 @@ def run_script_with_checkpoint(
404427
return script_path
405428

406429
missing_or_invalid = sum(
407-
1 for output_file in output_files if not is_checkpointed(branch, output_file)
430+
1 for decision in checkpoint_decisions if decision.action != "reuse"
408431
)
409432

410433
# Run the script
@@ -418,7 +441,9 @@ def run_script_with_checkpoint(
418441

419442
# Checkpoint all outputs
420443
for output_file in output_files:
421-
save_checkpoint(branch, output_file, volume)
444+
saved = checkpoint_store.save_output(output_file)
445+
if saved:
446+
print(f"Checkpointed: {output_file}")
422447
if checkpoint_stats is not None:
423448
checkpoint_stats.record(
424449
expected_outputs=expected_count,
@@ -444,6 +469,60 @@ def _stage_base_artifact_paths(artifacts_dir: Path) -> tuple[Path, ...]:
444469
return tuple(paths)
445470

446471

472+
def _stage_1_status_metadata(coordinator: Stage1Coordinator) -> dict[str, Any]:
473+
return {
474+
"substep_results": [
475+
result.to_dict() for result in getattr(coordinator, "results", ())
476+
],
477+
"status_events": [
478+
event.to_dict() for event in getattr(coordinator, "status_events", ())
479+
],
480+
"error_records": [
481+
error.to_dict() for error in getattr(coordinator, "error_records", ())
482+
],
483+
}
484+
485+
486+
def _compat_reuse_decision(
487+
*,
488+
script_path: str,
489+
output_files: list[str],
490+
branch: str,
491+
run_id: str | None = None,
492+
rerun_id: str | None = None,
493+
) -> Stage1ReuseDecision:
494+
"""Return the current compatible semantic reuse decision for a script."""
495+
496+
substep_id = stage_1_substep_id_for_script(script_path)
497+
material = Stage1IdentityMaterial(
498+
substep_id=substep_id,
499+
inputs={"script_path": script_path},
500+
parameters={"branch": branch, "outputs": output_files},
501+
artifact_specs=[
502+
{
503+
"filename": spec.filename,
504+
"logical_name": spec.logical_name,
505+
"storage_path": spec.storage_path,
506+
"substage_id": spec.substage_id,
507+
}
508+
for spec in stage_1_artifact_specs()
509+
if spec.script_path == script_path
510+
],
511+
code_sha=get_current_commit(),
512+
schema_version="stage-1-rerun-v1",
513+
upstream_contract_fingerprints=(),
514+
randomness={"checkpoint_scope": "branch_commit"},
515+
)
516+
planner = Stage1RerunPlanner(
517+
previous_identities={substep_id: material.fingerprint()}
518+
)
519+
return planner.decide(
520+
material,
521+
run_id=run_id or os.environ.get(RUN_ID_ENV, "unknown"),
522+
rerun_id=rerun_id,
523+
)
524+
525+
447526
def _run_checkpointed_substep(
448527
*,
449528
coordinator: Stage1Coordinator | None,
@@ -456,6 +535,26 @@ def _run_checkpointed_substep(
456535
checkpoint_stats: CheckpointStats | None = None,
457536
) -> str:
458537
command_results: list[DatasetCommandResult] = []
538+
output_list = output_files if isinstance(output_files, list) else [output_files]
539+
if coordinator is None:
540+
return run_script_with_checkpoint(
541+
script_path,
542+
output_files,
543+
branch,
544+
volume,
545+
env=env,
546+
log_file=log_file,
547+
checkpoint_stats=checkpoint_stats,
548+
command_results=command_results,
549+
)
550+
551+
checkpoint_store = _checkpoint_store(branch, volume)
552+
reuse_decision = _compat_reuse_decision(
553+
script_path=script_path,
554+
output_files=output_list,
555+
branch=branch,
556+
)
557+
checkpoint_decisions = checkpoint_store.decisions_for(output_list)
459558

460559
def action() -> str:
461560
return run_script_with_checkpoint(
@@ -466,11 +565,10 @@ def action() -> str:
466565
env=env,
467566
log_file=log_file,
468567
checkpoint_stats=checkpoint_stats,
469-
command_results=command_results,
568+
checkpoint_store=checkpoint_store,
569+
reuse_decision=reuse_decision,
470570
)
471571

472-
if coordinator is None:
473-
return action()
474572
substep_id = stage_1_substep_id_for_script(script_path)
475573
return coordinator.run_substep(
476574
substep_id,
@@ -479,6 +577,10 @@ def action() -> str:
479577
command_names=(script_path,),
480578
command_results=command_results,
481579
artifact_paths=_output_paths(output_files),
580+
reuse_decision=reuse_decision.to_dict(),
581+
checkpoint_decisions=tuple(
582+
decision.to_dict() for decision in checkpoint_decisions
583+
),
482584
aggregate=True,
483585
)
484586

@@ -583,6 +685,7 @@ def write_dataset_build_contract(
583685
package_version: str = DATA_PACKAGE_VERSION,
584686
branch: str = "unknown",
585687
diagnostics: tuple = (),
688+
stage_1_status_metadata: Mapping[str, Any] | None = None,
586689
) -> StageContract:
587690
"""Write the Stage 1 semantic handoff contract next to copied artifacts."""
588691
context = DatasetBuildContext(
@@ -602,6 +705,7 @@ def write_dataset_build_contract(
602705
skip_enhanced_cps=skip_enhanced_cps,
603706
skip_stage_5=skip_stage_5,
604707
diagnostics=diagnostics,
708+
stage_1_status_metadata=stage_1_status_metadata,
605709
)
606710

607711

@@ -691,14 +795,11 @@ def build_datasets(
691795
os.chdir("/root/policyengine-us-data")
692796

693797
# Clean stale checkpoints from other commits
694-
branch_dir = Path(VOLUME_MOUNT) / branch
695-
if branch_dir.exists():
696-
current_commit = get_current_commit()
697-
for entry in branch_dir.iterdir():
698-
if entry.is_dir() and entry.name != current_commit:
699-
shutil.rmtree(entry)
700-
print(f"Removed stale checkpoint dir: {entry.name[:12]}")
701-
checkpoint_volume.commit()
798+
for removed_checkpoint in _checkpoint_store(
799+
branch,
800+
checkpoint_volume,
801+
).cleanup_other_commits():
802+
print(f"Removed stale checkpoint dir: {removed_checkpoint.name[:12]}")
702803

703804
# Open persistent build log with provenance header
704805
commit = get_current_commit()
@@ -1032,6 +1133,7 @@ def run_stage_base_handoff() -> None:
10321133
package_version=version,
10331134
branch=branch,
10341135
diagnostics=diagnostics,
1136+
stage_1_status_metadata=_stage_1_status_metadata(coordinator),
10351137
)
10361138
pipeline_volume.commit()
10371139
print("Pipeline artifacts committed to shared volume")

policyengine_us_data/build_datasets/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
stage_1_pipeline_artifact_specs,
1010
stage_1_script_outputs,
1111
)
12+
from .checkpoints import (
13+
CheckpointDecision,
14+
CheckpointReuseSummary,
15+
CheckpointStore,
16+
)
1217
from .commands import (
1318
CommandRunner,
1419
DatasetCommand,
@@ -38,11 +43,19 @@
3843
stage_1_step_specs,
3944
)
4045
from .results import DatasetCommandResult, DatasetSubstepResult
46+
from .rerun import (
47+
Stage1IdentityMaterial,
48+
Stage1RerunPlanner,
49+
Stage1ReuseDecision,
50+
)
4151
from .staging import PipelineArtifactStager
4252
from .status import Stage1ErrorRecord, Stage1StatusEvent
4353

4454
__all__ = [
4555
"ARTIFACT_SCHEMA_VERSION",
56+
"CheckpointDecision",
57+
"CheckpointReuseSummary",
58+
"CheckpointStore",
4659
"CommandBackedSubstepRunner",
4760
"CommandRunner",
4861
"DatasetArtifactSpec",
@@ -61,6 +74,9 @@
6174
"SourceDatasetSchemaSummaryWriter",
6275
"Stage1Coordinator",
6376
"Stage1ErrorRecord",
77+
"Stage1IdentityMaterial",
78+
"Stage1RerunPlanner",
79+
"Stage1ReuseDecision",
6480
"Stage1StatusEvent",
6581
"Stage1SubstepRunner",
6682
"SubprocessLogCapture",

0 commit comments

Comments
 (0)