Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1048.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a Stage 1 dataset-build context, artifact stager, and diagnostic artifact writers for the pipeline handoff.
90 changes: 47 additions & 43 deletions modal_app/data_build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import json
import os
import shutil
import subprocess
Expand All @@ -22,14 +21,17 @@

from modal_app.images import cpu_image as image # noqa: E402
from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402
from policyengine_us_data.build_datasets import stage_1_script_outputs # noqa: E402
from policyengine_us_data.build_datasets import ( # noqa: E402
DatasetBuildContext,
DatasetBuildOutputContractBuilder,
PipelineArtifactStager,
stage_1_script_outputs,
write_stage_1_diagnostics,
)
from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402
from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402
from policyengine_us_data.stage_contracts import ( # noqa: E402
DATASET_BUILD_OUTPUT_CONTRACT_FILENAME,
StageContract,
build_dataset_build_output_contract,
write_contract,
)
from policyengine_us_data.utils.run_context import ( # noqa: E402
CANDIDATE_VERSION_ENV,
Expand Down Expand Up @@ -484,13 +486,18 @@ def write_dataset_build_contract(
skip_enhanced_cps: bool,
skip_stage_5: bool = False,
package_version: str = DATA_PACKAGE_VERSION,
branch: str = "unknown",
diagnostics: tuple = (),
) -> StageContract:
"""Write the Stage 1 semantic handoff contract next to copied artifacts."""
contract = build_dataset_build_output_contract(
artifacts_dir=artifacts_dir,
context = DatasetBuildContext(
run_id=run_id,
branch=branch,
code_sha=code_sha,
package_version=package_version,
artifacts_dir=artifacts_dir,
)
return DatasetBuildOutputContractBuilder(context=context).write(
checkpoint_stats=checkpoint_stats,
started_at=started_at,
completed_at=completed_at,
Expand All @@ -499,12 +506,8 @@ def write_dataset_build_contract(
stage_only=stage_only,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
diagnostics=diagnostics,
)
write_contract(
contract,
artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME,
)
return contract


@app.function(
Expand All @@ -529,7 +532,15 @@ def write_dataset_build_contract(
status="current",
stability="moving",
pathways=["data_build", "orchestration"],
artifacts_out=["source_imputed_*.h5", "policy_data.db"],
artifacts_out=[
"dataset_build_output.json",
"dataset_inventory.json",
"source_dataset_schema_summary.json",
"target_database_schema_summary.json",
"source_imputed_stratified_extended_cps_2024.h5",
"source_imputed_stratified_extended_cps.h5",
"policy_data.db",
],
validation_commands=["uv run pytest tests/unit/test_modal_data_build.py"],
)
)
Expand Down Expand Up @@ -810,41 +821,32 @@ def build_datasets(
artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts"
if run_id:
artifacts_dir = artifacts_dir / run_id
artifacts_dir.mkdir(parents=True, exist_ok=True)

# Copy all intermediate H5 datasets for lineage tracing
for output in SCRIPT_OUTPUTS.values():
paths = output if isinstance(output, list) else [output]
for p in paths:
src = Path(p)
if src.suffix == ".h5" and src.exists():
shutil.copy2(src, artifacts_dir / src.name)
print(
f" Copied {src.name} ({src.stat().st_size / 1024 / 1024:.1f} MB)"
)

# Yearless alias for pipeline consumers (remote_calibration_runner, local_area)
si = artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5"
if si.exists():
shutil.copy2(si, artifacts_dir / "source_imputed_stratified_extended_cps.h5")

shutil.copy2(
"policyengine_us_data/storage/calibration/policy_data.db",
artifacts_dir / "policy_data.db",
build_context = DatasetBuildContext(
run_id=run_id,
branch=branch,
code_sha=commit,
package_version=version,
artifacts_dir=artifacts_dir,
)
cal_weights = Path("policyengine_us_data/storage/calibration_weights.npy")
if cal_weights.exists():
shutil.copy2(
cal_weights,
artifacts_dir / "calibration_weights.npy",
stager = PipelineArtifactStager(context=build_context)
staged_paths = stager.stage_declared_artifacts(
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
)
for staged_path in staged_paths:
print(
f" Copied {staged_path.name} "
f"({staged_path.stat().st_size / 1024 / 1024:.1f} MB)"
)
print(" Copied calibration_weights.npy")
shutil.copy2(log_path, artifacts_dir / "build_log.txt")
checkpoint_snapshot = checkpoint_stats.snapshot()
with open(artifacts_dir / "data_build_checkpoint_stats.json", "w") as f:
json.dump(checkpoint_snapshot, f, indent=2, sort_keys=True)
stager.write_checkpoint_stats(checkpoint_snapshot)
log_file.close()
completed_at_dt = datetime.now(timezone.utc)
diagnostics = write_stage_1_diagnostics(
context=build_context,
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
)
write_dataset_build_contract(
artifacts_dir=artifacts_dir,
run_id=run_id,
Expand All @@ -858,6 +860,8 @@ def build_datasets(
skip_enhanced_cps=skip_enhanced_cps,
skip_stage_5=skip_stage_5,
package_version=version,
branch=branch,
diagnostics=diagnostics,
)
pipeline_volume.commit()
print("Pipeline artifacts committed to shared volume")
Expand Down
22 changes: 22 additions & 0 deletions policyengine_us_data/build_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,45 @@
STAGE_1_ARTIFACT_SPECS,
stage_1_artifact_specs,
stage_1_contract_artifact_specs,
stage_1_diagnostic_artifact_specs,
stage_1_pipeline_artifact_specs,
stage_1_script_outputs,
)
from .context import DatasetBuildContext
from .contracts import DatasetBuildOutputContractBuilder
from .diagnostics import (
ARTIFACT_SCHEMA_VERSION,
DatasetInventoryWriter,
SourceDatasetSchemaSummaryWriter,
TargetDatabaseSchemaSummaryWriter,
write_stage_1_diagnostics,
)
from .specs import (
DatasetBuildStepSpec,
STAGE_1_BUILD_DATASETS,
STAGE_1_BUILD_STEP_SPECS,
stage_1_step_specs,
)
from .staging import PipelineArtifactStager

# Names exported by this package's `__init__`; the entries visible here are
# re-exported from the sibling modules imported above (.context, .contracts,
# .diagnostics, .specs, .staging, and the artifact-spec helpers).
__all__ = [
"ARTIFACT_SCHEMA_VERSION",
"DatasetArtifactSpec",
"DatasetBuildContext",
"DatasetBuildOutputContractBuilder",
"DatasetBuildStepSpec",
"DatasetInventoryWriter",
"PipelineArtifactStager",
"STAGE_1_ARTIFACT_SPECS",
"STAGE_1_BUILD_DATASETS",
"STAGE_1_BUILD_STEP_SPECS",
"SourceDatasetSchemaSummaryWriter",
"TargetDatabaseSchemaSummaryWriter",
"stage_1_artifact_specs",
"stage_1_contract_artifact_specs",
"stage_1_diagnostic_artifact_specs",
"stage_1_pipeline_artifact_specs",
"stage_1_script_outputs",
"stage_1_step_specs",
"write_stage_1_diagnostics",
]
68 changes: 68 additions & 0 deletions policyengine_us_data/build_datasets/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class DatasetArtifactSpec:
required_for_stage_2: bool = False
yearless_alias: bool = False
contract_output: bool = True
pipeline_output: bool = True
diagnostic_output: bool = False
diagnostic_kind: str | None = None
skip_when_enhanced_cps_skipped: bool = False
skip_when_stage_5_skipped: bool = False

Expand Down Expand Up @@ -53,6 +56,7 @@ class DatasetArtifactSpec:
storage_path="policyengine_us_data/storage/uprating_factors.csv",
script_path=_UPRATING_SCRIPT,
contract_output=False,
pipeline_output=False,
),
DatasetArtifactSpec(
filename="acs_2022.h5",
Expand Down Expand Up @@ -120,6 +124,7 @@ class DatasetArtifactSpec:
),
script_path=_ENHANCED_CPS_SCRIPT,
contract_output=False,
pipeline_output=False,
skip_when_enhanced_cps_skipped=True,
),
DatasetArtifactSpec(
Expand All @@ -130,6 +135,7 @@ class DatasetArtifactSpec:
storage_path="calibration_log.csv",
script_path=_ENHANCED_CPS_SCRIPT,
contract_output=False,
pipeline_output=False,
skip_when_enhanced_cps_skipped=True,
),
DatasetArtifactSpec(
Expand Down Expand Up @@ -184,18 +190,62 @@ class DatasetArtifactSpec:
storage_path="policyengine_us_data/storage/calibration/policy_data.db",
required_for_stage_2=True,
),
DatasetArtifactSpec(
filename="calibration_weights.npy",
logical_name="calibration_weights",
artifact_family="legacy_optional_weight",
substage_id="1g_stage_base_datasets",
storage_path="policyengine_us_data/storage/calibration_weights.npy",
required=False,
contract_output=False,
),
DatasetArtifactSpec(
filename="build_log.txt",
logical_name="build_log",
artifact_family="log",
substage_id="1g_stage_base_datasets",
storage_path="build_log.txt",
),
DatasetArtifactSpec(
filename="data_build_checkpoint_stats.json",
logical_name="data_build_checkpoint_stats",
artifact_family="execution_metadata",
substage_id="1g_stage_base_datasets",
),
DatasetArtifactSpec(
filename="dataset_inventory.json",
logical_name="dataset_inventory",
artifact_family="diagnostic",
substage_id="1g_stage_base_datasets",
required=False,
contract_output=False,
pipeline_output=False,
diagnostic_output=True,
diagnostic_kind="dataset_inventory",
),
DatasetArtifactSpec(
filename="source_dataset_schema_summary.json",
logical_name="source_dataset_schema_summary",
artifact_family="diagnostic",
substage_id="1f_source_imputation",
required=False,
contract_output=False,
pipeline_output=False,
diagnostic_output=True,
diagnostic_kind="source_dataset_schema_summary",
skip_when_stage_5_skipped=True,
),
DatasetArtifactSpec(
filename="target_database_schema_summary.json",
logical_name="target_database_schema_summary",
artifact_family="diagnostic",
substage_id="1g_stage_base_datasets",
required=False,
contract_output=False,
pipeline_output=False,
diagnostic_output=True,
diagnostic_kind="target_database_schema_summary",
),
)


Expand Down Expand Up @@ -223,8 +273,12 @@ class DatasetArtifactSpec:
"small_enhanced_cps_2024.h5",
"source_imputed_stratified_extended_cps.h5",
"policy_data.db",
"calibration_weights.npy",
"build_log.txt",
"data_build_checkpoint_stats.json",
"dataset_inventory.json",
"source_dataset_schema_summary.json",
"target_database_schema_summary.json",
],
validation_commands=["uv run pytest tests/unit/test_build_dataset_specs.py"],
)
Expand All @@ -240,6 +294,18 @@ def stage_1_contract_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.contract_output)


def stage_1_pipeline_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
    """Return the Stage 1 specs staged into the run-scoped pipeline directory.

    Returns:
        The entries of ``STAGE_1_ARTIFACT_SPECS`` whose ``pipeline_output``
        flag is truthy, in their declaration order.
    """
    staged_specs = filter(
        lambda artifact_spec: artifact_spec.pipeline_output,
        STAGE_1_ARTIFACT_SPECS,
    )
    return tuple(staged_specs)


def stage_1_diagnostic_artifact_specs() -> tuple[DatasetArtifactSpec, ...]:
    """Return the diagnostic artifact specs emitted by Stage 1 writers.

    Returns:
        The entries of ``STAGE_1_ARTIFACT_SPECS`` whose ``diagnostic_output``
        flag is truthy, in their declaration order.
    """
    diagnostic_specs = filter(
        lambda artifact_spec: artifact_spec.diagnostic_output,
        STAGE_1_ARTIFACT_SPECS,
    )
    return tuple(diagnostic_specs)


def stage_1_script_outputs() -> Mapping[str, ScriptOutput]:
"""Return the checkpoint output mapping consumed by Modal data-build."""

Expand All @@ -261,5 +327,7 @@ def stage_1_script_outputs() -> Mapping[str, ScriptOutput]:
"ScriptOutput",
"stage_1_artifact_specs",
"stage_1_contract_artifact_specs",
"stage_1_diagnostic_artifact_specs",
"stage_1_pipeline_artifact_specs",
"stage_1_script_outputs",
]
Loading
Loading