1+ from __future__ import annotations
2+
13import functools
24import os
35import shutil
2224from modal_app .images import cpu_image as image # noqa: E402
2325from policyengine_us_data .__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402
2426from policyengine_us_data .build_datasets import ( # noqa: E402
27+ CheckpointStore ,
2528 CommandRunner ,
2629 DatasetCommand ,
2730 DatasetCommandError ,
3033 DatasetBuildOutputContractBuilder ,
3134 PipelineArtifactStager ,
3235 Stage1Coordinator ,
36+ Stage1IdentityMaterial ,
37+ Stage1RerunPlanner ,
38+ Stage1ReuseDecision ,
3339 stage_1_artifact_specs ,
3440 stage_1_script_outputs ,
3541 stage_1_substep_id_for_script ,
@@ -185,30 +191,20 @@ def get_current_commit() -> str:
185191
186192def get_checkpoint_path (branch : str , output_file : str ) -> Path :
187193 """Get the checkpoint path for an output file, scoped by branch and commit."""
188- commit = get_current_commit ()
189- return Path (VOLUME_MOUNT ) / branch / commit / Path (output_file ).name
194+ return _checkpoint_store (branch ).checkpoint_path (output_file )
190195
191196
192197def is_checkpointed (branch : str , output_file : str ) -> bool :
193198 """Check if output file exists in checkpoint volume and is valid."""
194- checkpoint_path = get_checkpoint_path (branch , output_file )
195- if checkpoint_path .exists ():
196- # Verify file is not empty/corrupted
197- if checkpoint_path .stat ().st_size > 0 :
198- return True
199- return False
199+ return _checkpoint_store (branch ).decision_for (output_file ).action == "reuse"
200200
201201
202202def restore_from_checkpoint (branch : str , output_file : str ) -> bool :
203203 """Restore output file from checkpoint volume if it exists."""
204- checkpoint_path = get_checkpoint_path (branch , output_file )
205- if checkpoint_path .exists () and checkpoint_path .stat ().st_size > 0 :
206- local_path = Path (output_file )
207- local_path .parent .mkdir (parents = True , exist_ok = True )
208- shutil .copy2 (checkpoint_path , local_path )
204+ restored = _checkpoint_store (branch ).restore_output (output_file )
205+ if restored :
209206 print (f"Restored from checkpoint: { output_file } " )
210- return True
211- return False
207+ return restored
212208
213209
214210def save_checkpoint (
@@ -217,25 +213,35 @@ def save_checkpoint(
217213 volume : modal .Volume ,
218214) -> None :
219215 """Save output file to checkpoint volume."""
220- local_path = Path (output_file )
221- if local_path .exists () and local_path .stat ().st_size > 0 :
222- checkpoint_path = get_checkpoint_path (branch , output_file )
223- checkpoint_path .parent .mkdir (parents = True , exist_ok = True )
224- shutil .copy2 (local_path , checkpoint_path )
225- with _volume_lock :
226- volume .commit ()
216+ saved = _checkpoint_store (branch , volume ).save_output (output_file )
217+ if saved :
227218 print (f"Checkpointed: { output_file } " )
228219
229220
230221def cleanup_checkpoints (branch : str , volume : modal .Volume ) -> None :
231222 """Delete checkpoints for this branch after successful completion."""
232- branch_dir = Path (VOLUME_MOUNT ) / branch
233- if branch_dir .exists ():
234- shutil .rmtree (branch_dir )
235- volume .commit ()
223+ cleaned = _checkpoint_store (branch , volume ).cleanup_branch ()
224+ if cleaned :
236225 print (f"Cleaned up checkpoints for branch: { branch } " )
237226
238227
228+ def _checkpoint_store (
229+ branch : str ,
230+ volume : modal .Volume | None = None ,
231+ ) -> CheckpointStore :
232+ def commit_volume () -> None :
233+ if volume is not None :
234+ with _volume_lock :
235+ volume .commit ()
236+
237+ return CheckpointStore (
238+ root = Path (VOLUME_MOUNT ),
239+ branch = branch ,
240+ commit_sha = get_current_commit (),
241+ commit = commit_volume if volume is not None else None ,
242+ )
243+
244+
239245def run_script_logged (
240246 cmd : list ,
241247 log_file : IO ,
@@ -368,6 +374,8 @@ def run_script_with_checkpoint(
368374 log_file : IO = None ,
369375 checkpoint_stats : CheckpointStats | None = None ,
370376 command_results : list [DatasetCommandResult ] | None = None ,
377+ checkpoint_store : CheckpointStore | None = None ,
378+ reuse_decision : Stage1ReuseDecision | None = None ,
371379) -> str :
372380 """Run script if output not checkpointed, then checkpoint result.
373381
@@ -387,14 +395,29 @@ def run_script_with_checkpoint(
387395 if isinstance (output_files , str ):
388396 output_files = [output_files ]
389397 expected_count = len (output_files )
398+ checkpoint_store = checkpoint_store or _checkpoint_store (branch , volume )
399+ reuse_decision = reuse_decision or _compat_reuse_decision (
400+ script_path = script_path ,
401+ output_files = output_files ,
402+ branch = branch ,
403+ )
404+ if reuse_decision .action == "blocked" :
405+ raise RuntimeError (
406+ "Stage 1 checkpoint reuse is blocked for "
407+ f"{ script_path } : { reuse_decision .reason } "
408+ )
390409
391410 # Check if ALL outputs are checkpointed
392- all_checkpointed = all (is_checkpointed (branch , f ) for f in output_files )
411+ checkpoint_decisions = checkpoint_store .decisions_for (output_files )
412+ all_checkpointed = reuse_decision .action == "reuse" and all (
413+ decision .action == "reuse" for decision in checkpoint_decisions
414+ )
393415
394416 if all_checkpointed :
395417 # Restore all files from checkpoint
418+ checkpoint_store .restore_all_outputs (output_files )
396419 for output_file in output_files :
397- restore_from_checkpoint ( branch , output_file )
420+ print ( f"Restored from checkpoint: { output_file } " )
398421 print (f"Skipping { script_path } (restored from checkpoint)" )
399422 if checkpoint_stats is not None :
400423 checkpoint_stats .record (
@@ -404,7 +427,7 @@ def run_script_with_checkpoint(
404427 return script_path
405428
406429 missing_or_invalid = sum (
407- 1 for output_file in output_files if not is_checkpointed ( branch , output_file )
430+ 1 for decision in checkpoint_decisions if decision . action != "reuse"
408431 )
409432
410433 # Run the script
@@ -418,7 +441,9 @@ def run_script_with_checkpoint(
418441
419442 # Checkpoint all outputs
420443 for output_file in output_files :
421- save_checkpoint (branch , output_file , volume )
444+ saved = checkpoint_store .save_output (output_file )
445+ if saved :
446+ print (f"Checkpointed: { output_file } " )
422447 if checkpoint_stats is not None :
423448 checkpoint_stats .record (
424449 expected_outputs = expected_count ,
@@ -444,6 +469,60 @@ def _stage_base_artifact_paths(artifacts_dir: Path) -> tuple[Path, ...]:
444469 return tuple (paths )
445470
446471
472+ def _stage_1_status_metadata (coordinator : Stage1Coordinator ) -> dict [str , Any ]:
473+ return {
474+ "substep_results" : [
475+ result .to_dict () for result in getattr (coordinator , "results" , ())
476+ ],
477+ "status_events" : [
478+ event .to_dict () for event in getattr (coordinator , "status_events" , ())
479+ ],
480+ "error_records" : [
481+ error .to_dict () for error in getattr (coordinator , "error_records" , ())
482+ ],
483+ }
484+
485+
486+ def _compat_reuse_decision (
487+ * ,
488+ script_path : str ,
489+ output_files : list [str ],
490+ branch : str ,
491+ run_id : str | None = None ,
492+ rerun_id : str | None = None ,
493+ ) -> Stage1ReuseDecision :
494+ """Return the current compatible semantic reuse decision for a script."""
495+
496+ substep_id = stage_1_substep_id_for_script (script_path )
497+ material = Stage1IdentityMaterial (
498+ substep_id = substep_id ,
499+ inputs = {"script_path" : script_path },
500+ parameters = {"branch" : branch , "outputs" : output_files },
501+ artifact_specs = [
502+ {
503+ "filename" : spec .filename ,
504+ "logical_name" : spec .logical_name ,
505+ "storage_path" : spec .storage_path ,
506+ "substage_id" : spec .substage_id ,
507+ }
508+ for spec in stage_1_artifact_specs ()
509+ if spec .script_path == script_path
510+ ],
511+ code_sha = get_current_commit (),
512+ schema_version = "stage-1-rerun-v1" ,
513+ upstream_contract_fingerprints = (),
514+ randomness = {"checkpoint_scope" : "branch_commit" },
515+ )
516+ planner = Stage1RerunPlanner (
517+ previous_identities = {substep_id : material .fingerprint ()}
518+ )
519+ return planner .decide (
520+ material ,
521+ run_id = run_id or os .environ .get (RUN_ID_ENV , "unknown" ),
522+ rerun_id = rerun_id ,
523+ )
524+
525+
447526def _run_checkpointed_substep (
448527 * ,
449528 coordinator : Stage1Coordinator | None ,
@@ -456,6 +535,26 @@ def _run_checkpointed_substep(
456535 checkpoint_stats : CheckpointStats | None = None ,
457536) -> str :
458537 command_results : list [DatasetCommandResult ] = []
538+ output_list = output_files if isinstance (output_files , list ) else [output_files ]
539+ if coordinator is None :
540+ return run_script_with_checkpoint (
541+ script_path ,
542+ output_files ,
543+ branch ,
544+ volume ,
545+ env = env ,
546+ log_file = log_file ,
547+ checkpoint_stats = checkpoint_stats ,
548+ command_results = command_results ,
549+ )
550+
551+ checkpoint_store = _checkpoint_store (branch , volume )
552+ reuse_decision = _compat_reuse_decision (
553+ script_path = script_path ,
554+ output_files = output_list ,
555+ branch = branch ,
556+ )
557+ checkpoint_decisions = checkpoint_store .decisions_for (output_list )
459558
460559 def action () -> str :
461560 return run_script_with_checkpoint (
@@ -466,11 +565,10 @@ def action() -> str:
466565 env = env ,
467566 log_file = log_file ,
468567 checkpoint_stats = checkpoint_stats ,
469- command_results = command_results ,
568+ checkpoint_store = checkpoint_store ,
569+ reuse_decision = reuse_decision ,
470570 )
471571
472- if coordinator is None :
473- return action ()
474572 substep_id = stage_1_substep_id_for_script (script_path )
475573 return coordinator .run_substep (
476574 substep_id ,
@@ -479,6 +577,10 @@ def action() -> str:
479577 command_names = (script_path ,),
480578 command_results = command_results ,
481579 artifact_paths = _output_paths (output_files ),
580+ reuse_decision = reuse_decision .to_dict (),
581+ checkpoint_decisions = tuple (
582+ decision .to_dict () for decision in checkpoint_decisions
583+ ),
482584 aggregate = True ,
483585 )
484586
@@ -583,6 +685,7 @@ def write_dataset_build_contract(
583685 package_version : str = DATA_PACKAGE_VERSION ,
584686 branch : str = "unknown" ,
585687 diagnostics : tuple = (),
688+ stage_1_status_metadata : Mapping [str , Any ] | None = None ,
586689) -> StageContract :
587690 """Write the Stage 1 semantic handoff contract next to copied artifacts."""
588691 context = DatasetBuildContext (
@@ -602,6 +705,7 @@ def write_dataset_build_contract(
602705 skip_enhanced_cps = skip_enhanced_cps ,
603706 skip_stage_5 = skip_stage_5 ,
604707 diagnostics = diagnostics ,
708+ stage_1_status_metadata = stage_1_status_metadata ,
605709 )
606710
607711
@@ -691,14 +795,11 @@ def build_datasets(
691795 os .chdir ("/root/policyengine-us-data" )
692796
693797 # Clean stale checkpoints from other commits
694- branch_dir = Path (VOLUME_MOUNT ) / branch
695- if branch_dir .exists ():
696- current_commit = get_current_commit ()
697- for entry in branch_dir .iterdir ():
698- if entry .is_dir () and entry .name != current_commit :
699- shutil .rmtree (entry )
700- print (f"Removed stale checkpoint dir: { entry .name [:12 ]} " )
701- checkpoint_volume .commit ()
798+ for removed_checkpoint in _checkpoint_store (
799+ branch ,
800+ checkpoint_volume ,
801+ ).cleanup_other_commits ():
802+ print (f"Removed stale checkpoint dir: { removed_checkpoint .name [:12 ]} " )
702803
703804 # Open persistent build log with provenance header
704805 commit = get_current_commit ()
@@ -1032,6 +1133,7 @@ def run_stage_base_handoff() -> None:
10321133 package_version = version ,
10331134 branch = branch ,
10341135 diagnostics = diagnostics ,
1136+ stage_1_status_metadata = _stage_1_status_metadata (coordinator ),
10351137 )
10361138 pipeline_volume .commit ()
10371139 print ("Pipeline artifacts committed to shared volume" )
0 commit comments