Skip to content

Commit c968309

Browse files
authored
Merge pull request #1050 from PolicyEngine/agent/stage-1/pr-2-build-context-artifact-staging
Add Stage 1 artifact staging boundary
2 parents 275a1d3 + 9877333 commit c968309

15 files changed

Lines changed: 1010 additions & 48 deletions

changelog.d/1048.added

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added a Stage 1 dataset-build context, artifact stager, and diagnostic artifact writers for the pipeline handoff.

modal_app/data_build.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import functools
2-
import json
32
import os
43
import shutil
54
import subprocess
@@ -22,14 +21,17 @@
2221

2322
from modal_app.images import cpu_image as image # noqa: E402
2423
from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402
25-
from policyengine_us_data.build_datasets import stage_1_script_outputs # noqa: E402
24+
from policyengine_us_data.build_datasets import ( # noqa: E402
25+
DatasetBuildContext,
26+
DatasetBuildOutputContractBuilder,
27+
PipelineArtifactStager,
28+
stage_1_script_outputs,
29+
write_stage_1_diagnostics,
30+
)
2631
from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402
2732
from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402
2833
from policyengine_us_data.stage_contracts import ( # noqa: E402
29-
DATASET_BUILD_OUTPUT_CONTRACT_FILENAME,
3034
StageContract,
31-
build_dataset_build_output_contract,
32-
write_contract,
3335
)
3436
from policyengine_us_data.utils.run_context import ( # noqa: E402
3537
CANDIDATE_VERSION_ENV,
@@ -484,13 +486,18 @@ def write_dataset_build_contract(
484486
skip_enhanced_cps: bool,
485487
skip_stage_5: bool = False,
486488
package_version: str = DATA_PACKAGE_VERSION,
489+
branch: str = "unknown",
490+
diagnostics: tuple = (),
487491
) -> StageContract:
488492
"""Write the Stage 1 semantic handoff contract next to copied artifacts."""
489-
contract = build_dataset_build_output_contract(
490-
artifacts_dir=artifacts_dir,
493+
context = DatasetBuildContext(
491494
run_id=run_id,
495+
branch=branch,
492496
code_sha=code_sha,
493497
package_version=package_version,
498+
artifacts_dir=artifacts_dir,
499+
)
500+
return DatasetBuildOutputContractBuilder(context=context).write(
494501
checkpoint_stats=checkpoint_stats,
495502
started_at=started_at,
496503
completed_at=completed_at,
@@ -499,12 +506,8 @@ def write_dataset_build_contract(
499506
stage_only=stage_only,
500507
skip_enhanced_cps=skip_enhanced_cps,
501508
skip_stage_5=skip_stage_5,
509+
diagnostics=diagnostics,
502510
)
503-
write_contract(
504-
contract,
505-
artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME,
506-
)
507-
return contract
508511

509512

510513
@app.function(
@@ -529,7 +532,15 @@ def write_dataset_build_contract(
529532
status="current",
530533
stability="moving",
531534
pathways=["data_build", "orchestration"],
532-
artifacts_out=["source_imputed_*.h5", "policy_data.db"],
535+
artifacts_out=[
536+
"dataset_build_output.json",
537+
"dataset_inventory.json",
538+
"source_dataset_schema_summary.json",
539+
"target_database_schema_summary.json",
540+
"source_imputed_stratified_extended_cps_2024.h5",
541+
"source_imputed_stratified_extended_cps.h5",
542+
"policy_data.db",
543+
],
533544
validation_commands=["uv run pytest tests/unit/test_modal_data_build.py"],
534545
)
535546
)
@@ -810,41 +821,32 @@ def build_datasets(
810821
artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts"
811822
if run_id:
812823
artifacts_dir = artifacts_dir / run_id
813-
artifacts_dir.mkdir(parents=True, exist_ok=True)
814-
815-
# Copy all intermediate H5 datasets for lineage tracing
816-
for output in SCRIPT_OUTPUTS.values():
817-
paths = output if isinstance(output, list) else [output]
818-
for p in paths:
819-
src = Path(p)
820-
if src.suffix == ".h5" and src.exists():
821-
shutil.copy2(src, artifacts_dir / src.name)
822-
print(
823-
f" Copied {src.name} ({src.stat().st_size / 1024 / 1024:.1f} MB)"
824-
)
825-
826-
# Yearless alias for pipeline consumers (remote_calibration_runner, local_area)
827-
si = artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5"
828-
if si.exists():
829-
shutil.copy2(si, artifacts_dir / "source_imputed_stratified_extended_cps.h5")
830-
831-
shutil.copy2(
832-
"policyengine_us_data/storage/calibration/policy_data.db",
833-
artifacts_dir / "policy_data.db",
824+
build_context = DatasetBuildContext(
825+
run_id=run_id,
826+
branch=branch,
827+
code_sha=commit,
828+
package_version=version,
829+
artifacts_dir=artifacts_dir,
834830
)
835-
cal_weights = Path("policyengine_us_data/storage/calibration_weights.npy")
836-
if cal_weights.exists():
837-
shutil.copy2(
838-
cal_weights,
839-
artifacts_dir / "calibration_weights.npy",
831+
stager = PipelineArtifactStager(context=build_context)
832+
staged_paths = stager.stage_declared_artifacts(
833+
skip_enhanced_cps=skip_enhanced_cps,
834+
skip_stage_5=skip_stage_5,
835+
)
836+
for staged_path in staged_paths:
837+
print(
838+
f" Copied {staged_path.name} "
839+
f"({staged_path.stat().st_size / 1024 / 1024:.1f} MB)"
840840
)
841-
print(" Copied calibration_weights.npy")
842-
shutil.copy2(log_path, artifacts_dir / "build_log.txt")
843841
checkpoint_snapshot = checkpoint_stats.snapshot()
844-
with open(artifacts_dir / "data_build_checkpoint_stats.json", "w") as f:
845-
json.dump(checkpoint_snapshot, f, indent=2, sort_keys=True)
842+
stager.write_checkpoint_stats(checkpoint_snapshot)
846843
log_file.close()
847844
completed_at_dt = datetime.now(timezone.utc)
845+
diagnostics = write_stage_1_diagnostics(
846+
context=build_context,
847+
skip_enhanced_cps=skip_enhanced_cps,
848+
skip_stage_5=skip_stage_5,
849+
)
848850
write_dataset_build_contract(
849851
artifacts_dir=artifacts_dir,
850852
run_id=run_id,
@@ -858,6 +860,8 @@ def build_datasets(
858860
skip_enhanced_cps=skip_enhanced_cps,
859861
skip_stage_5=skip_stage_5,
860862
package_version=version,
863+
branch=branch,
864+
diagnostics=diagnostics,
861865
)
862866
pipeline_volume.commit()
863867
print("Pipeline artifacts committed to shared volume")

policyengine_us_data/build_datasets/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,45 @@
55
STAGE_1_ARTIFACT_SPECS,
66
stage_1_artifact_specs,
77
stage_1_contract_artifact_specs,
8+
stage_1_diagnostic_artifact_specs,
9+
stage_1_pipeline_artifact_specs,
810
stage_1_script_outputs,
911
)
12+
from .context import DatasetBuildContext
13+
from .contracts import DatasetBuildOutputContractBuilder
14+
from .diagnostics import (
15+
ARTIFACT_SCHEMA_VERSION,
16+
DatasetInventoryWriter,
17+
SourceDatasetSchemaSummaryWriter,
18+
TargetDatabaseSchemaSummaryWriter,
19+
write_stage_1_diagnostics,
20+
)
1021
from .specs import (
1122
DatasetBuildStepSpec,
1223
STAGE_1_BUILD_DATASETS,
1324
STAGE_1_BUILD_STEP_SPECS,
1425
stage_1_step_specs,
1526
)
27+
from .staging import PipelineArtifactStager
1628

1729
__all__ = [
30+
"ARTIFACT_SCHEMA_VERSION",
1831
"DatasetArtifactSpec",
32+
"DatasetBuildContext",
33+
"DatasetBuildOutputContractBuilder",
1934
"DatasetBuildStepSpec",
35+
"DatasetInventoryWriter",
36+
"PipelineArtifactStager",
2037
"STAGE_1_ARTIFACT_SPECS",
2138
"STAGE_1_BUILD_DATASETS",
2239
"STAGE_1_BUILD_STEP_SPECS",
40+
"SourceDatasetSchemaSummaryWriter",
41+
"TargetDatabaseSchemaSummaryWriter",
2342
"stage_1_artifact_specs",
2443
"stage_1_contract_artifact_specs",
44+
"stage_1_diagnostic_artifact_specs",
45+
"stage_1_pipeline_artifact_specs",
2546
"stage_1_script_outputs",
2647
"stage_1_step_specs",
48+
"write_stage_1_diagnostics",
2749
]

policyengine_us_data/build_datasets/artifacts.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class DatasetArtifactSpec:
2626
required_for_stage_2: bool = False
2727
yearless_alias: bool = False
2828
contract_output: bool = True
29+
pipeline_output: bool = True
30+
diagnostic_output: bool = False
31+
diagnostic_kind: str | None = None
2932
skip_when_enhanced_cps_skipped: bool = False
3033
skip_when_stage_5_skipped: bool = False
3134

@@ -53,6 +56,7 @@ class DatasetArtifactSpec:
5356
storage_path="policyengine_us_data/storage/uprating_factors.csv",
5457
script_path=_UPRATING_SCRIPT,
5558
contract_output=False,
59+
pipeline_output=False,
5660
),
5761
DatasetArtifactSpec(
5862
filename="acs_2022.h5",
@@ -120,6 +124,7 @@ class DatasetArtifactSpec:
120124
),
121125
script_path=_ENHANCED_CPS_SCRIPT,
122126
contract_output=False,
127+
pipeline_output=False,
123128
skip_when_enhanced_cps_skipped=True,
124129
),
125130
DatasetArtifactSpec(
@@ -130,6 +135,7 @@ class DatasetArtifactSpec:
130135
storage_path="calibration_log.csv",
131136
script_path=_ENHANCED_CPS_SCRIPT,
132137
contract_output=False,
138+
pipeline_output=False,
133139
skip_when_enhanced_cps_skipped=True,
134140
),
135141
DatasetArtifactSpec(
@@ -184,18 +190,62 @@ class DatasetArtifactSpec:
184190
storage_path="policyengine_us_data/storage/calibration/policy_data.db",
185191
required_for_stage_2=True,
186192
),
193+
DatasetArtifactSpec(
194+
filename="calibration_weights.npy",
195+
logical_name="calibration_weights",
196+
artifact_family="legacy_optional_weight",
197+
substage_id="1g_stage_base_datasets",
198+
storage_path="policyengine_us_data/storage/calibration_weights.npy",
199+
required=False,
200+
contract_output=False,
201+
),
187202
DatasetArtifactSpec(
188203
filename="build_log.txt",
189204
logical_name="build_log",
190205
artifact_family="log",
191206
substage_id="1g_stage_base_datasets",
207+
storage_path="build_log.txt",
192208
),
193209
DatasetArtifactSpec(
194210
filename="data_build_checkpoint_stats.json",
195211
logical_name="data_build_checkpoint_stats",
196212
artifact_family="execution_metadata",
197213
substage_id="1g_stage_base_datasets",
198214
),
215+
DatasetArtifactSpec(
216+
filename="dataset_inventory.json",
217+
logical_name="dataset_inventory",
218+
artifact_family="diagnostic",
219+
substage_id="1g_stage_base_datasets",
220+
required=False,
221+
contract_output=False,
222+
pipeline_output=False,
223+
diagnostic_output=True,
224+
diagnostic_kind="dataset_inventory",
225+
),
226+
DatasetArtifactSpec(
227+
filename="source_dataset_schema_summary.json",
228+
logical_name="source_dataset_schema_summary",
229+
artifact_family="diagnostic",
230+
substage_id="1f_source_imputation",
231+
required=False,
232+
contract_output=False,
233+
pipeline_output=False,
234+
diagnostic_output=True,
235+
diagnostic_kind="source_dataset_schema_summary",
236+
skip_when_stage_5_skipped=True,
237+
),
238+
DatasetArtifactSpec(
239+
filename="target_database_schema_summary.json",
240+
logical_name="target_database_schema_summary",
241+
artifact_family="diagnostic",
242+
substage_id="1g_stage_base_datasets",
243+
required=False,
244+
contract_output=False,
245+
pipeline_output=False,
246+
diagnostic_output=True,
247+
diagnostic_kind="target_database_schema_summary",
248+
),
199249
)
200250

201251

@@ -223,8 +273,12 @@ class DatasetArtifactSpec:
223273
"small_enhanced_cps_2024.h5",
224274
"source_imputed_stratified_extended_cps.h5",
225275
"policy_data.db",
276+
"calibration_weights.npy",
226277
"build_log.txt",
227278
"data_build_checkpoint_stats.json",
279+
"dataset_inventory.json",
280+
"source_dataset_schema_summary.json",
281+
"target_database_schema_summary.json",
228282
],
229283
validation_commands=["uv run pytest tests/unit/test_build_dataset_specs.py"],
230284
)
@@ -240,6 +294,18 @@ def stage_1_contract_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
240294
return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.contract_output)
241295

242296

297+
def stage_1_pipeline_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
298+
"""Return artifact specs staged into the run-scoped pipeline directory."""
299+
300+
return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.pipeline_output)
301+
302+
303+
def stage_1_diagnostic_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
304+
"""Return diagnostic artifact specs emitted by Stage 1 writers."""
305+
306+
return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.diagnostic_output)
307+
308+
243309
def stage_1_script_outputs() -> Mapping[str, ScriptOutput]:
244310
"""Return the checkpoint output mapping consumed by Modal data-build."""
245311

@@ -261,5 +327,7 @@ def stage_1_script_outputs() -> Mapping[str, ScriptOutput]:
261327
"ScriptOutput",
262328
"stage_1_artifact_specs",
263329
"stage_1_contract_artifact_specs",
330+
"stage_1_diagnostic_artifact_specs",
331+
"stage_1_pipeline_artifact_specs",
264332
"stage_1_script_outputs",
265333
]

0 commit comments

Comments
 (0)