11import functools
2- import json
32import os
43import shutil
54import subprocess
2221
2322from modal_app .images import cpu_image as image # noqa: E402
2423from policyengine_us_data .__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402
25- from policyengine_us_data .build_datasets import stage_1_script_outputs # noqa: E402
24+ from policyengine_us_data .build_datasets import ( # noqa: E402
25+ DatasetBuildContext ,
26+ DatasetBuildOutputContractBuilder ,
27+ PipelineArtifactStager ,
28+ stage_1_script_outputs ,
29+ write_stage_1_diagnostics ,
30+ )
2631from policyengine_us_data .pipeline_metadata import pipeline_node # noqa: E402
2732from policyengine_us_data .pipeline_schema import PipelineNode # noqa: E402
2833from policyengine_us_data .stage_contracts import ( # noqa: E402
29- DATASET_BUILD_OUTPUT_CONTRACT_FILENAME ,
3034 StageContract ,
31- build_dataset_build_output_contract ,
32- write_contract ,
3335)
3436from policyengine_us_data .utils .run_context import ( # noqa: E402
3537 CANDIDATE_VERSION_ENV ,
@@ -484,13 +486,18 @@ def write_dataset_build_contract(
484486 skip_enhanced_cps : bool ,
485487 skip_stage_5 : bool = False ,
486488 package_version : str = DATA_PACKAGE_VERSION ,
489+ branch : str = "unknown" ,
490+ diagnostics : tuple = (),
487491) -> StageContract :
488492 """Write the Stage 1 semantic handoff contract next to copied artifacts."""
489- contract = build_dataset_build_output_contract (
490- artifacts_dir = artifacts_dir ,
493+ context = DatasetBuildContext (
491494 run_id = run_id ,
495+ branch = branch ,
492496 code_sha = code_sha ,
493497 package_version = package_version ,
498+ artifacts_dir = artifacts_dir ,
499+ )
500+ return DatasetBuildOutputContractBuilder (context = context ).write (
494501 checkpoint_stats = checkpoint_stats ,
495502 started_at = started_at ,
496503 completed_at = completed_at ,
@@ -499,12 +506,8 @@ def write_dataset_build_contract(
499506 stage_only = stage_only ,
500507 skip_enhanced_cps = skip_enhanced_cps ,
501508 skip_stage_5 = skip_stage_5 ,
509+ diagnostics = diagnostics ,
502510 )
503- write_contract (
504- contract ,
505- artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME ,
506- )
507- return contract
508511
509512
510513@app .function (
@@ -529,7 +532,15 @@ def write_dataset_build_contract(
529532 status = "current" ,
530533 stability = "moving" ,
531534 pathways = ["data_build" , "orchestration" ],
532- artifacts_out = ["source_imputed_*.h5" , "policy_data.db" ],
535+ artifacts_out = [
536+ "dataset_build_output.json" ,
537+ "dataset_inventory.json" ,
538+ "source_dataset_schema_summary.json" ,
539+ "target_database_schema_summary.json" ,
540+ "source_imputed_stratified_extended_cps_2024.h5" ,
541+ "source_imputed_stratified_extended_cps.h5" ,
542+ "policy_data.db" ,
543+ ],
533544 validation_commands = ["uv run pytest tests/unit/test_modal_data_build.py" ],
534545 )
535546)
@@ -810,41 +821,32 @@ def build_datasets(
810821 artifacts_dir = Path (PIPELINE_MOUNT ) / "artifacts"
811822 if run_id :
812823 artifacts_dir = artifacts_dir / run_id
813- artifacts_dir .mkdir (parents = True , exist_ok = True )
814-
815- # Copy all intermediate H5 datasets for lineage tracing
816- for output in SCRIPT_OUTPUTS .values ():
817- paths = output if isinstance (output , list ) else [output ]
818- for p in paths :
819- src = Path (p )
820- if src .suffix == ".h5" and src .exists ():
821- shutil .copy2 (src , artifacts_dir / src .name )
822- print (
823- f" Copied { src .name } ({ src .stat ().st_size / 1024 / 1024 :.1f} MB)"
824- )
825-
826- # Yearless alias for pipeline consumers (remote_calibration_runner, local_area)
827- si = artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5"
828- if si .exists ():
829- shutil .copy2 (si , artifacts_dir / "source_imputed_stratified_extended_cps.h5" )
830-
831- shutil .copy2 (
832- "policyengine_us_data/storage/calibration/policy_data.db" ,
833- artifacts_dir / "policy_data.db" ,
824+ build_context = DatasetBuildContext (
825+ run_id = run_id ,
826+ branch = branch ,
827+ code_sha = commit ,
828+ package_version = version ,
829+ artifacts_dir = artifacts_dir ,
834830 )
835- cal_weights = Path ("policyengine_us_data/storage/calibration_weights.npy" )
836- if cal_weights .exists ():
837- shutil .copy2 (
838- cal_weights ,
839- artifacts_dir / "calibration_weights.npy" ,
831+ stager = PipelineArtifactStager (context = build_context )
832+ staged_paths = stager .stage_declared_artifacts (
833+ skip_enhanced_cps = skip_enhanced_cps ,
834+ skip_stage_5 = skip_stage_5 ,
835+ )
836+ for staged_path in staged_paths :
837+ print (
838+ f" Copied { staged_path .name } "
839+ f"({ staged_path .stat ().st_size / 1024 / 1024 :.1f} MB)"
840840 )
841- print (" Copied calibration_weights.npy" )
842- shutil .copy2 (log_path , artifacts_dir / "build_log.txt" )
843841 checkpoint_snapshot = checkpoint_stats .snapshot ()
844- with open (artifacts_dir / "data_build_checkpoint_stats.json" , "w" ) as f :
845- json .dump (checkpoint_snapshot , f , indent = 2 , sort_keys = True )
842+ stager .write_checkpoint_stats (checkpoint_snapshot )
846843 log_file .close ()
847844 completed_at_dt = datetime .now (timezone .utc )
845+ diagnostics = write_stage_1_diagnostics (
846+ context = build_context ,
847+ skip_enhanced_cps = skip_enhanced_cps ,
848+ skip_stage_5 = skip_stage_5 ,
849+ )
848850 write_dataset_build_contract (
849851 artifacts_dir = artifacts_dir ,
850852 run_id = run_id ,
@@ -858,6 +860,8 @@ def build_datasets(
858860 skip_enhanced_cps = skip_enhanced_cps ,
859861 skip_stage_5 = skip_stage_5 ,
860862 package_version = version ,
863+ branch = branch ,
864+ diagnostics = diagnostics ,
861865 )
862866 pipeline_volume .commit ()
863867 print ("Pipeline artifacts committed to shared volume" )
0 commit comments