diff --git a/changelog.d/1048.added b/changelog.d/1048.added new file mode 100644 index 000000000..054951215 --- /dev/null +++ b/changelog.d/1048.added @@ -0,0 +1 @@ +Added a Stage 1 dataset-build context, artifact stager, and diagnostic artifact writers for the pipeline handoff. diff --git a/modal_app/data_build.py b/modal_app/data_build.py index e629ed615..76e5f8800 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -1,5 +1,4 @@ import functools -import json import os import shutil import subprocess @@ -22,14 +21,17 @@ from modal_app.images import cpu_image as image # noqa: E402 from policyengine_us_data.__version__ import __version__ as DATA_PACKAGE_VERSION # noqa: E402 -from policyengine_us_data.build_datasets import stage_1_script_outputs # noqa: E402 +from policyengine_us_data.build_datasets import ( # noqa: E402 + DatasetBuildContext, + DatasetBuildOutputContractBuilder, + PipelineArtifactStager, + stage_1_script_outputs, + write_stage_1_diagnostics, +) from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402 from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.stage_contracts import ( # noqa: E402 - DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, StageContract, - build_dataset_build_output_contract, - write_contract, ) from policyengine_us_data.utils.run_context import ( # noqa: E402 CANDIDATE_VERSION_ENV, @@ -484,13 +486,18 @@ def write_dataset_build_contract( skip_enhanced_cps: bool, skip_stage_5: bool = False, package_version: str = DATA_PACKAGE_VERSION, + branch: str = "unknown", + diagnostics: tuple = (), ) -> StageContract: """Write the Stage 1 semantic handoff contract next to copied artifacts.""" - contract = build_dataset_build_output_contract( - artifacts_dir=artifacts_dir, + context = DatasetBuildContext( run_id=run_id, + branch=branch, code_sha=code_sha, package_version=package_version, + artifacts_dir=artifacts_dir, + ) + return DatasetBuildOutputContractBuilder(context=context).write( checkpoint_stats=checkpoint_stats, started_at=started_at, completed_at=completed_at, @@ -499,12 +506,8 @@ def write_dataset_build_contract( stage_only=stage_only, skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, + diagnostics=diagnostics, ) - write_contract( - contract, - artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, - ) - return contract @app.function( @@ -529,7 +532,15 @@ def write_dataset_build_contract( status="current", stability="moving", pathways=["data_build", "orchestration"], - artifacts_out=["source_imputed_*.h5", "policy_data.db"], + artifacts_out=[ + "dataset_build_output.json", + "dataset_inventory.json", + "source_dataset_schema_summary.json", + "target_database_schema_summary.json", + "source_imputed_stratified_extended_cps_2024.h5", + "source_imputed_stratified_extended_cps.h5", + "policy_data.db", + ], validation_commands=["uv run pytest tests/unit/test_modal_data_build.py"], ) ) @@ -810,41 +821,32 @@ def build_datasets( artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts" if run_id: artifacts_dir = artifacts_dir / run_id - artifacts_dir.mkdir(parents=True, exist_ok=True) - - # Copy all intermediate H5 datasets for lineage tracing - for output in SCRIPT_OUTPUTS.values(): - paths = output if isinstance(output, list) else [output] - for p in paths: - src = Path(p) - if src.suffix == ".h5" and src.exists(): - shutil.copy2(src, artifacts_dir / src.name) - print( - f" Copied {src.name} ({src.stat().st_size / 1024 / 1024:.1f} MB)" - ) - - # Yearless alias for pipeline consumers (remote_calibration_runner, local_area) - si = artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5" - if si.exists(): - shutil.copy2(si, artifacts_dir / "source_imputed_stratified_extended_cps.h5") - - shutil.copy2( - "policyengine_us_data/storage/calibration/policy_data.db", - artifacts_dir / "policy_data.db", + build_context = DatasetBuildContext( + run_id=run_id, + branch=branch, + code_sha=commit, + package_version=version, + artifacts_dir=artifacts_dir, ) - cal_weights = Path("policyengine_us_data/storage/calibration_weights.npy") - if cal_weights.exists(): - shutil.copy2( - cal_weights, - artifacts_dir / "calibration_weights.npy", + stager = PipelineArtifactStager(context=build_context) + staged_paths = stager.stage_declared_artifacts( + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ) + for staged_path in staged_paths: + print( + f" Copied {staged_path.name} " + f"({staged_path.stat().st_size / 1024 / 1024:.1f} MB)" ) - print(" Copied calibration_weights.npy") - shutil.copy2(log_path, artifacts_dir / "build_log.txt") checkpoint_snapshot = checkpoint_stats.snapshot() - with open(artifacts_dir / "data_build_checkpoint_stats.json", "w") as f: - json.dump(checkpoint_snapshot, f, indent=2, sort_keys=True) + stager.write_checkpoint_stats(checkpoint_snapshot) log_file.close() completed_at_dt = datetime.now(timezone.utc) + diagnostics = write_stage_1_diagnostics( + context=build_context, + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ) write_dataset_build_contract( artifacts_dir=artifacts_dir, run_id=run_id, @@ -858,6 +860,8 @@ def build_datasets( skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, package_version=version, + branch=branch, + diagnostics=diagnostics, ) pipeline_volume.commit() print("Pipeline artifacts committed to shared volume") diff --git a/policyengine_us_data/build_datasets/__init__.py b/policyengine_us_data/build_datasets/__init__.py index e6b86fabf..9f69dedd5 100644 --- a/policyengine_us_data/build_datasets/__init__.py +++ b/policyengine_us_data/build_datasets/__init__.py @@ -5,23 +5,45 @@ STAGE_1_ARTIFACT_SPECS, stage_1_artifact_specs, stage_1_contract_artifact_specs, + stage_1_diagnostic_artifact_specs, + stage_1_pipeline_artifact_specs, stage_1_script_outputs, ) +from .context import DatasetBuildContext +from .contracts import DatasetBuildOutputContractBuilder +from .diagnostics import ( + ARTIFACT_SCHEMA_VERSION, + DatasetInventoryWriter, + SourceDatasetSchemaSummaryWriter, + TargetDatabaseSchemaSummaryWriter, + write_stage_1_diagnostics, +) from .specs import ( DatasetBuildStepSpec, STAGE_1_BUILD_DATASETS, STAGE_1_BUILD_STEP_SPECS, stage_1_step_specs, ) +from .staging import PipelineArtifactStager __all__ = [ + "ARTIFACT_SCHEMA_VERSION", "DatasetArtifactSpec", + "DatasetBuildContext", + "DatasetBuildOutputContractBuilder", "DatasetBuildStepSpec", + "DatasetInventoryWriter", + "PipelineArtifactStager", "STAGE_1_ARTIFACT_SPECS", "STAGE_1_BUILD_DATASETS", "STAGE_1_BUILD_STEP_SPECS", + "SourceDatasetSchemaSummaryWriter", + "TargetDatabaseSchemaSummaryWriter", "stage_1_artifact_specs", "stage_1_contract_artifact_specs", + "stage_1_diagnostic_artifact_specs", + "stage_1_pipeline_artifact_specs", "stage_1_script_outputs", "stage_1_step_specs", + "write_stage_1_diagnostics", ] diff --git a/policyengine_us_data/build_datasets/artifacts.py b/policyengine_us_data/build_datasets/artifacts.py index 70fd70afd..969b1ffe3 100644 --- a/policyengine_us_data/build_datasets/artifacts.py +++ b/policyengine_us_data/build_datasets/artifacts.py @@ -26,6 +26,9 @@ class DatasetArtifactSpec: required_for_stage_2: bool = False yearless_alias: bool = False contract_output: bool = True + pipeline_output: bool = True + diagnostic_output: bool = False + diagnostic_kind: str | None = None skip_when_enhanced_cps_skipped: bool = False skip_when_stage_5_skipped: bool = False @@ -53,6 +56,7 @@ class DatasetArtifactSpec: storage_path="policyengine_us_data/storage/uprating_factors.csv", script_path=_UPRATING_SCRIPT, contract_output=False, + pipeline_output=False, ), DatasetArtifactSpec( filename="acs_2022.h5", @@ -120,6 +124,7 @@ class DatasetArtifactSpec: ), script_path=_ENHANCED_CPS_SCRIPT, contract_output=False, + pipeline_output=False, skip_when_enhanced_cps_skipped=True, ), DatasetArtifactSpec( @@ -130,6 +135,7 @@ class DatasetArtifactSpec: storage_path="calibration_log.csv", script_path=_ENHANCED_CPS_SCRIPT, contract_output=False, + pipeline_output=False, skip_when_enhanced_cps_skipped=True, ), DatasetArtifactSpec( @@ -184,11 +190,21 @@ class DatasetArtifactSpec: storage_path="policyengine_us_data/storage/calibration/policy_data.db", required_for_stage_2=True, ), + DatasetArtifactSpec( + filename="calibration_weights.npy", + logical_name="calibration_weights", + artifact_family="legacy_optional_weight", + substage_id="1g_stage_base_datasets", + storage_path="policyengine_us_data/storage/calibration_weights.npy", + required=False, + contract_output=False, + ), DatasetArtifactSpec( filename="build_log.txt", logical_name="build_log", artifact_family="log", substage_id="1g_stage_base_datasets", + storage_path="build_log.txt", ), DatasetArtifactSpec( filename="data_build_checkpoint_stats.json", @@ -196,6 +212,40 @@ class DatasetArtifactSpec: artifact_family="execution_metadata", substage_id="1g_stage_base_datasets", ), + DatasetArtifactSpec( + filename="dataset_inventory.json", + logical_name="dataset_inventory", + artifact_family="diagnostic", + substage_id="1g_stage_base_datasets", + required=False, + contract_output=False, + pipeline_output=False, + diagnostic_output=True, + diagnostic_kind="dataset_inventory", + ), + DatasetArtifactSpec( + filename="source_dataset_schema_summary.json", + logical_name="source_dataset_schema_summary", + artifact_family="diagnostic", + substage_id="1f_source_imputation", + required=False, + contract_output=False, + pipeline_output=False, + diagnostic_output=True, + diagnostic_kind="source_dataset_schema_summary", + skip_when_stage_5_skipped=True, + ), + DatasetArtifactSpec( + filename="target_database_schema_summary.json", + logical_name="target_database_schema_summary", + artifact_family="diagnostic", + substage_id="1g_stage_base_datasets", + required=False, + contract_output=False, + pipeline_output=False, + diagnostic_output=True, + diagnostic_kind="target_database_schema_summary", + ), ) @@ -223,8 +273,12 @@ class DatasetArtifactSpec: "small_enhanced_cps_2024.h5", "source_imputed_stratified_extended_cps.h5", "policy_data.db", + "calibration_weights.npy", "build_log.txt", "data_build_checkpoint_stats.json", + "dataset_inventory.json", + "source_dataset_schema_summary.json", + "target_database_schema_summary.json", ], validation_commands=["uv run pytest tests/unit/test_build_dataset_specs.py"], ) @@ -240,6 +294,18 @@ def stage_1_contract_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.contract_output) +def stage_1_pipeline_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: + """Return artifact specs staged into the run-scoped pipeline directory.""" + + return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.pipeline_output) + + +def stage_1_diagnostic_artifact_specs() -> tuple[DatasetArtifactSpec, ...]: + """Return diagnostic artifact specs emitted by Stage 1 writers.""" + + return tuple(spec for spec in STAGE_1_ARTIFACT_SPECS if spec.diagnostic_output) + + def stage_1_script_outputs() -> Mapping[str, ScriptOutput]: """Return the checkpoint output mapping consumed by Modal data-build.""" @@ -261,5 +327,7 @@ def stage_1_script_outputs() -> Mapping[str, ScriptOutput]: "ScriptOutput", "stage_1_artifact_specs", "stage_1_contract_artifact_specs", + "stage_1_diagnostic_artifact_specs", + "stage_1_pipeline_artifact_specs", "stage_1_script_outputs", ] diff --git a/policyengine_us_data/build_datasets/context.py b/policyengine_us_data/build_datasets/context.py new file mode 100644 index 000000000..16e57825d --- /dev/null +++ b/policyengine_us_data/build_datasets/context.py @@ -0,0 +1,66 @@ +"""Run context for the Stage 1 dataset-build handoff.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from .specs import STAGE_1_BUILD_DATASETS + + +@dataclass(frozen=True, kw_only=True) +class DatasetBuildContext: + """Identity and filesystem context for one Stage 1 dataset-build run.""" + + run_id: str + branch: str + code_sha: str + package_version: str + artifacts_dir: Path + storage_dir: Path = Path("policyengine_us_data/storage") + work_dir: Path = Path(".") + stage_id: str = STAGE_1_BUILD_DATASETS + + def __post_init__(self) -> None: + if not self.run_id: + raise ValueError("run_id is required") + if not self.branch: + raise ValueError("branch is required") + if not self.code_sha: + raise ValueError("code_sha is required") + if not self.package_version: + raise ValueError("package_version is required") + object.__setattr__(self, "artifacts_dir", Path(self.artifacts_dir)) + object.__setattr__(self, "storage_dir", Path(self.storage_dir)) + object.__setattr__(self, "work_dir", Path(self.work_dir)) + + def source_path(self, storage_path: str) -> Path: + """Resolve a declared storage or working-directory source path.""" + + path = Path(storage_path) + if path.is_absolute(): + return path + storage_prefix = Path("policyengine_us_data/storage") + try: + return self.storage_dir / path.relative_to(storage_prefix) + except ValueError: + return self.work_dir / path + + def artifact_path(self, filename: str) -> Path: + """Return the run-scoped destination path for a staged artifact.""" + + return self.artifacts_dir / filename + + def identity(self) -> dict[str, str]: + """Return stable identity fields for Stage 1 diagnostic payloads.""" + + return { + "run_id": self.run_id, + "stage_id": self.stage_id, + "branch": self.branch, + "code_sha": self.code_sha, + "package_version": self.package_version, + } + + +__all__ = ["DatasetBuildContext"] diff --git a/policyengine_us_data/build_datasets/contracts.py b/policyengine_us_data/build_datasets/contracts.py new file mode 100644 index 000000000..13a627daf --- /dev/null +++ b/policyengine_us_data/build_datasets/contracts.py @@ -0,0 +1,68 @@ +"""Contract builder facade for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +from .context import DatasetBuildContext + + +@dataclass(frozen=True, kw_only=True) +class DatasetBuildOutputContractBuilder: + """Build and persist the Stage 1 dataset-build handoff contract.""" + + context: DatasetBuildContext + + def build( + self, + *, + checkpoint_stats: Mapping[str, int], + started_at: str | None, + completed_at: str, + duration_s: float | None, + upload_requested: bool, + stage_only: bool, + skip_enhanced_cps: bool, + skip_stage_5: bool = False, + diagnostics: Sequence[object] = (), + ): + """Build the Stage 1 handoff contract from staged artifacts.""" + + from policyengine_us_data.stage_contracts import ( + build_dataset_build_output_contract, + ) + + return build_dataset_build_output_contract( + artifacts_dir=self.context.artifacts_dir, + run_id=self.context.run_id, + code_sha=self.context.code_sha, + package_version=self.context.package_version, + checkpoint_stats=checkpoint_stats, + started_at=started_at, + completed_at=completed_at, + duration_s=duration_s, + upload_requested=upload_requested, + stage_only=stage_only, + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + diagnostics=tuple(diagnostics), + ) + + def write(self, **kwargs): + """Build and write the Stage 1 handoff contract next to artifacts.""" + + from policyengine_us_data.stage_contracts import ( + DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + write_contract, + ) + + contract = self.build(**kwargs) + write_contract( + contract, + self.context.artifacts_dir / DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + ) + return contract + + +__all__ = ["DatasetBuildOutputContractBuilder"] diff --git a/policyengine_us_data/build_datasets/diagnostics.py b/policyengine_us_data/build_datasets/diagnostics.py new file mode 100644 index 000000000..f42ead19c --- /dev/null +++ b/policyengine_us_data/build_datasets/diagnostics.py @@ -0,0 +1,391 @@ +"""Diagnostic artifact writers for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +import json +import sqlite3 +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .artifacts import ( + DatasetArtifactSpec, + stage_1_diagnostic_artifact_specs, + stage_1_pipeline_artifact_specs, +) +from .context import DatasetBuildContext +from policyengine_us_data.utils.step_manifest import sha256_file + + +ARTIFACT_SCHEMA_VERSION = "1" + + +def _json_default(value: Any) -> Any: + if isinstance(value, Path): + return str(value) + raise TypeError(f"Object is not JSON serializable: {type(value).__name__}") + + +def _write_json(path: Path, payload: Mapping[str, Any]) -> None: + path.write_text( + json.dumps( + payload, + default=_json_default, + indent=2, + sort_keys=True, + ) + + "\n" + ) + + +def _media_type_for_path(path: Path) -> str: + suffix = path.suffix.lower() + if suffix == ".h5": + return "application/x-hdf5" + if suffix == ".db": + return "application/vnd.sqlite3" + if suffix == ".json": + return "application/json" + if suffix == ".npy": + return "application/x-numpy-array" + if suffix == ".txt": + return "text/plain" + return "application/octet-stream" + + +def _artifact_ref_for_path( + *, + logical_name: str, + path: Path, + metadata: Mapping[str, Any], +): + from policyengine_us_data.stage_contracts import ArtifactRef + + return ArtifactRef( + logical_name=logical_name, + uri=path.resolve().as_uri(), + sha256=f"sha256:{sha256_file(path)}", + size_bytes=path.stat().st_size, + media_type=_media_type_for_path(path), + metadata=metadata, + ) + + +def _diagnostic_ref_for_path( + *, + spec: DatasetArtifactSpec, + path: Path, + summary: Mapping[str, Any], +): + from policyengine_us_data.stage_contracts import DiagnosticRef + + return DiagnosticRef( + name=spec.logical_name, + kind=spec.diagnostic_kind or spec.artifact_family, + artifact=_artifact_ref_for_path( + logical_name=spec.logical_name, + path=path, + metadata={ + "artifact_family": spec.artifact_family, + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + "substage_id": spec.substage_id, + }, + ), + summary=summary, + severity="info", + ) + + +def _diagnostic_spec(logical_name: str) -> DatasetArtifactSpec: + for spec in stage_1_diagnostic_artifact_specs(): + if spec.logical_name == logical_name: + return spec + raise KeyError(f"Unknown Stage 1 diagnostic spec: {logical_name}") + + +def _cheap_h5_summary(path: Path) -> dict[str, Any]: + import h5py + + datasets: list[dict[str, Any]] = [] + entities: dict[str, dict[str, Any]] = {} + + with h5py.File(path, "r") as h5_file: + + def visit(name: str, obj: Any) -> None: + if not isinstance(obj, h5py.Dataset): + return + parts = name.split("/") + entity = parts[0] if parts else "" + period = parts[-1] if parts[-1].isdigit() else None + variable = parts[-2] if period is not None and len(parts) > 1 else parts[-1] + row_count = int(obj.shape[0]) if obj.shape else None + datasets.append( + { + "path": name, + "entity": entity, + "variable": variable, + "period": period, + "dtype": str(obj.dtype), + "shape": list(obj.shape), + "row_count": row_count, + } + ) + entity_summary = entities.setdefault( + entity, + { + "dataset_count": 0, + "variables": set(), + "periods": set(), + "row_counts": {}, + }, + ) + entity_summary["dataset_count"] += 1 + entity_summary["variables"].add(variable) + if period is not None: + entity_summary["periods"].add(period) + if row_count is not None: + entity_summary["row_counts"][name] = row_count + + h5_file.visititems(visit) + + return { + "datasets": datasets, + "entities": { + entity: { + "dataset_count": summary["dataset_count"], + "variables": sorted(summary["variables"]), + "periods": sorted(summary["periods"]), + "row_counts": summary["row_counts"], + } + for entity, summary in sorted(entities.items()) + }, + } + + +def _sqlite_summary(path: Path) -> dict[str, Any]: + tables = [] + with sqlite3.connect(f"file:{path}?mode=ro", uri=True) as conn: + conn.row_factory = sqlite3.Row + table_names = [ + row["name"] + for row in conn.execute( + """ + SELECT name + FROM sqlite_master + WHERE type = 'table' AND name NOT LIKE 'sqlite_%' + ORDER BY name + """ + ) + ] + checksum_material = [] + for table_name in table_names: + quoted_table_name = _quote_sql_identifier(table_name) + columns = [ + { + "name": row["name"], + "type": row["type"], + "notnull": int(row["notnull"]), + "pk": int(row["pk"]), + } + for row in conn.execute(f"PRAGMA table_info({quoted_table_name})") + ] + row_count = conn.execute( + f"SELECT COUNT(*) AS row_count FROM {quoted_table_name}" + ).fetchone()["row_count"] + table_summary = { + "name": table_name, + "columns": columns, + "row_count": int(row_count), + } + tables.append(table_summary) + checksum_material.append(table_summary) + + digest_payload = json.dumps( + checksum_material, + sort_keys=True, + separators=(",", ":"), + ).encode() + import hashlib + + return { + "tables": tables, + "known_target_tables": [ + table["name"] + for table in tables + if table["name"] in {"targets", "strata", "stratum_constraints"} + ], + "schema_checksum": hashlib.sha256(digest_payload).hexdigest(), + } + + +def _quote_sql_identifier(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + +@dataclass(frozen=True, kw_only=True) +class DatasetInventoryWriter: + """Write a compact inventory of Stage 1 artifacts staged for a run.""" + + context: DatasetBuildContext + + def write( + self, + *, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, + ): + spec = _diagnostic_spec("dataset_inventory") + artifacts = [] + seen_logical_names: set[str] = set() + for artifact_spec in stage_1_pipeline_artifact_specs(): + if skip_enhanced_cps and artifact_spec.skip_when_enhanced_cps_skipped: + continue + if skip_stage_5 and artifact_spec.skip_when_stage_5_skipped: + continue + path = self.context.artifact_path(artifact_spec.filename) + if not path.exists(): + if artifact_spec.required: + raise FileNotFoundError(f"Missing staged artifact: {path}") + continue + if artifact_spec.logical_name in seen_logical_names: + raise ValueError( + f"Duplicate Stage 1 artifact: {artifact_spec.logical_name}" + ) + seen_logical_names.add(artifact_spec.logical_name) + artifacts.append(_inventory_entry(artifact_spec, path)) + + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "artifacts": artifacts, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={"artifact_count": len(artifacts)}, + ) + + +@dataclass(frozen=True, kw_only=True) +class SourceDatasetSchemaSummaryWriter: + """Write a metadata-only schema summary for the source-imputed H5 handoff.""" + + context: DatasetBuildContext + + def write(self): + spec = _diagnostic_spec("source_dataset_schema_summary") + source_path = self.context.artifact_path( + "source_imputed_stratified_extended_cps.h5" + ) + if not source_path.exists(): + raise FileNotFoundError(f"Missing source dataset artifact: {source_path}") + h5_summary = _cheap_h5_summary(source_path) + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "logical_name": "source_imputed_stratified_extended_cps", + "path": source_path.name, + **h5_summary, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={ + "entity_count": len(h5_summary["entities"]), + "dataset_count": len(h5_summary["datasets"]), + }, + ) + + +@dataclass(frozen=True, kw_only=True) +class TargetDatabaseSchemaSummaryWriter: + """Write a schema and row-count summary for the Stage 1 target database.""" + + context: DatasetBuildContext + + def write(self): + spec = _diagnostic_spec("target_database_schema_summary") + db_path = self.context.artifact_path("policy_data.db") + if not db_path.exists(): + raise FileNotFoundError(f"Missing target database artifact: {db_path}") + db_summary = _sqlite_summary(db_path) + payload = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + **self.context.identity(), + "logical_name": "policy_data_db", + "path": db_path.name, + **db_summary, + } + path = self.context.artifact_path(spec.filename) + _write_json(path, payload) + return _diagnostic_ref_for_path( + spec=spec, + path=path, + summary={ + "table_count": len(db_summary["tables"]), + "known_target_tables": db_summary["known_target_tables"], + "schema_checksum": db_summary["schema_checksum"], + }, + ) + + +def _inventory_entry(spec: DatasetArtifactSpec, path: Path) -> dict[str, Any]: + entry: dict[str, Any] = { + "artifact_schema_version": ARTIFACT_SCHEMA_VERSION, + "logical_name": spec.logical_name, + "artifact_family": spec.artifact_family, + "substage_id": spec.substage_id, + "path": path.name, + "sha256": f"sha256:{sha256_file(path)}", + "size_bytes": path.stat().st_size, + "media_type": _media_type_for_path(path), + } + if spec.period is not None: + entry["period"] = spec.period + if path.suffix == ".h5": + entry["row_counts"] = { + dataset["path"]: dataset["row_count"] + for dataset in _cheap_h5_summary(path)["datasets"] + if dataset["row_count"] is not None + } + elif path.suffix == ".db": + db_summary = _sqlite_summary(path) + entry["row_counts"] = { + table["name"]: table["row_count"] for table in db_summary["tables"] + } + entry["schema_checksum"] = db_summary["schema_checksum"] + return entry + + +def write_stage_1_diagnostics( + *, + context: DatasetBuildContext, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, +) -> tuple[Any, ...]: + """Write Stage 1 diagnostic artifacts and return their contract refs.""" + + refs = [ + DatasetInventoryWriter(context=context).write( + skip_enhanced_cps=skip_enhanced_cps, + skip_stage_5=skip_stage_5, + ), + TargetDatabaseSchemaSummaryWriter(context=context).write(), + ] + if not skip_stage_5: + refs.insert(1, SourceDatasetSchemaSummaryWriter(context=context).write()) + return tuple(refs) + + +__all__ = [ + "ARTIFACT_SCHEMA_VERSION", + "DatasetInventoryWriter", + "SourceDatasetSchemaSummaryWriter", + "TargetDatabaseSchemaSummaryWriter", + "write_stage_1_diagnostics", +] diff --git a/policyengine_us_data/build_datasets/staging.py b/policyengine_us_data/build_datasets/staging.py new file mode 100644 index 000000000..4956d4e60 --- /dev/null +++ b/policyengine_us_data/build_datasets/staging.py @@ -0,0 +1,83 @@ +"""Artifact staging helpers for Stage 1 dataset-build outputs.""" + +from __future__ import annotations + +import json +import shutil +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path + +from .artifacts import ( + DatasetArtifactSpec, + stage_1_pipeline_artifact_specs, +) +from .context import DatasetBuildContext + + +@dataclass(frozen=True, kw_only=True) +class PipelineArtifactStager: + """Stage declared Stage 1 artifacts into a run-scoped pipeline directory.""" + + context: DatasetBuildContext + + def stage_declared_artifacts( + self, + *, + skip_enhanced_cps: bool = False, + skip_stage_5: bool = False, + ) -> tuple[Path, ...]: + self.context.artifacts_dir.mkdir(parents=True, exist_ok=True) + staged: list[Path] = [] + missing_required: list[str] = [] + + for spec in stage_1_pipeline_artifact_specs(): + if skip_enhanced_cps and spec.skip_when_enhanced_cps_skipped: + continue + if skip_stage_5 and spec.skip_when_stage_5_skipped: + continue + if spec.yearless_alias: + alias = self._stage_yearless_alias(spec) + if alias is not None: + staged.append(alias) + continue + if spec.storage_path is None: + continue + + source = self.context.source_path(spec.storage_path) + destination = self.context.artifact_path(spec.filename) + if not source.exists(): + if spec.required: + missing_required.append(spec.filename) + continue + shutil.copy2(source, destination) + staged.append(destination) + + if missing_required: + raise FileNotFoundError( + "Missing Stage 1 pipeline artifact(s): " + + ", ".join(sorted(missing_required)) + ) + return tuple(staged) + + def write_checkpoint_stats(self, checkpoint_stats: Mapping[str, int]) -> Path: + """Write checkpoint reuse metadata as an explicit Stage 1 artifact.""" + + path = self.context.artifact_path("data_build_checkpoint_stats.json") + path.write_text( + json.dumps(dict(checkpoint_stats), indent=2, sort_keys=True) + "\n" + ) + return path + + def _stage_yearless_alias(self, spec: DatasetArtifactSpec) -> Path | None: + source = self.context.artifact_path( + "source_imputed_stratified_extended_cps_2024.h5" + ) + if not source.exists(): + return None + destination = self.context.artifact_path(spec.filename) + shutil.copy2(source, destination) + return destination + + +__all__ = ["PipelineArtifactStager"] diff --git a/policyengine_us_data/stage_contracts/dataset_build.py b/policyengine_us_data/stage_contracts/dataset_build.py index 3b3f7f4b2..53843f953 100644 --- a/policyengine_us_data/stage_contracts/dataset_build.py +++ b/policyengine_us_data/stage_contracts/dataset_build.py @@ -14,6 +14,7 @@ from .artifacts import ArtifactRef from .contracts import StageContract +from .diagnostics import DiagnosticRef from .execution import ExecutionRecord, ReuseSummary from .fingerprints import fingerprint_material from .stages import STAGE_1_BUILD_DATASETS, contract_type_for_stage @@ -37,6 +38,7 @@ def build_dataset_build_output_contract( stage_only: bool = False, skip_enhanced_cps: bool = False, skip_stage_5: bool = False, + diagnostics: tuple[DiagnosticRef, ...] = (), ) -> StageContract: """Build the Stage 1 handoff contract from copied pipeline artifacts.""" @@ -80,11 +82,13 @@ def build_dataset_build_output_contract( skip_enhanced_cps=skip_enhanced_cps, skip_stage_5=skip_stage_5, ), + diagnostics=diagnostics, execution=execution, metadata={ "artifact_count": len(outputs), "artifact_directory": str(artifacts_dir), "contract_file": DATASET_BUILD_OUTPUT_CONTRACT_FILENAME, + "diagnostic_count": len(diagnostics), }, ) diff --git a/pyproject.toml b/pyproject.toml index a9224268a..907d463d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "policyengine-us==1.702.0", + "policyengine-us==1.702.1", # policyengine-core 3.26.1 is the current 3.26.x runtime and includes the fix for # PolicyEngine/policyengine-core#482 (user-set ETERNITY inputs lost # after _invalidate_all_caches) and is required by policyengine-us 1.682.1+. diff --git a/tests/unit/test_build_dataset_specs.py b/tests/unit/test_build_dataset_specs.py index 070874cae..1bf6b189f 100644 --- a/tests/unit/test_build_dataset_specs.py +++ b/tests/unit/test_build_dataset_specs.py @@ -4,6 +4,8 @@ STAGE_1_BUILD_DATASETS, STAGE_1_BUILD_STEP_SPECS, stage_1_contract_artifact_specs, + stage_1_diagnostic_artifact_specs, + stage_1_pipeline_artifact_specs, stage_1_script_outputs, ) from policyengine_us_data.stage_contracts import ( @@ -99,6 +101,16 @@ def test_stage_1_contract_outputs_are_explicit_subset(): assert all(spec.contract_output for spec in contract_specs) +def test_stage_1_diagnostics_are_not_storage_staging_specs(): + diagnostic_specs = stage_1_diagnostic_artifact_specs() + + assert diagnostic_specs == tuple( + spec for spec in STAGE_1_ARTIFACT_SPECS if spec.diagnostic_output + ) + assert all(not spec.pipeline_output for spec in diagnostic_specs) + assert not any(spec.diagnostic_output for spec in stage_1_pipeline_artifact_specs()) + + def test_step_manifest_stage_1_substeps_match_dataset_build_specs(): assert tuple(substep.id for substep in BUILD_DATASETS.substeps) == tuple( spec.id for spec in STAGE_1_BUILD_STEP_SPECS @@ -131,6 +143,7 @@ def test_stage_1_skip_flags_identify_expected_artifacts(): } <= enhanced_cps_skipped assert { "small_enhanced_cps_2024.h5", + "source_dataset_schema_summary.json", "source_imputed_stratified_extended_cps_2024.h5", "source_imputed_stratified_extended_cps.h5", } == stage_5_skipped diff --git a/tests/unit/test_build_dataset_staging.py b/tests/unit/test_build_dataset_staging.py new file mode 100644 index 000000000..dceec2175 --- /dev/null +++ b/tests/unit/test_build_dataset_staging.py @@ -0,0 +1,217 @@ +import json +import sqlite3 +from pathlib import Path + +import h5py +import pytest + +from policyengine_us_data.build_datasets import ( + DatasetBuildContext, + DatasetBuildOutputContractBuilder, + DatasetInventoryWriter, + PipelineArtifactStager, + SourceDatasetSchemaSummaryWriter, + TargetDatabaseSchemaSummaryWriter, + stage_1_pipeline_artifact_specs, + write_stage_1_diagnostics, +) + + +def _context(tmp_path: Path) -> DatasetBuildContext: + return DatasetBuildContext( + run_id="run-123", + branch="main", + code_sha="abc123", + package_version="1.98.2", + artifacts_dir=tmp_path / "artifacts", + storage_dir=tmp_path / "storage", + work_dir=tmp_path / "work", + ) + + +def _write_h5(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(path, "w") as h5_file: + person = h5_file.create_group("person") + person.create_dataset("age/2024", data=[1, 2, 3]) + person.create_dataset("is_disabled", data=[0, 1, 0]) + household = h5_file.create_group("household") + household.create_dataset("weight/2024", data=[10.0, 20.0]) + + +def _write_sqlite(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with sqlite3.connect(path) as conn: + conn.execute("CREATE TABLE targets (id INTEGER PRIMARY KEY, value REAL)") + conn.execute("CREATE TABLE notes (id INTEGER PRIMARY KEY, label TEXT)") + conn.execute("INSERT INTO targets (value) VALUES (1.5), (2.5)") + conn.execute("INSERT INTO notes (label) VALUES ('a')") + + +def _write_text(path: Path, payload: str = "x") -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(payload) + + +def _write_required_storage_artifacts( + context: DatasetBuildContext, + *, + include_enhanced_cps: bool = True, + include_stage_5: bool = True, + include_optional_weights: bool = False, +) -> None: + for spec in stage_1_pipeline_artifact_specs(): + if spec.yearless_alias or spec.storage_path is None: + continue + if not include_enhanced_cps and spec.skip_when_enhanced_cps_skipped: + continue + if not include_stage_5 and spec.skip_when_stage_5_skipped: + continue + if not spec.required and not include_optional_weights: + continue + path = context.source_path(spec.storage_path) + if path.suffix == ".h5": + _write_h5(path) + elif path.suffix == ".db": + _write_sqlite(path) + else: + _write_text(path, spec.logical_name) + + +def test_stager_copies_only_declared_artifacts(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context, include_optional_weights=True) + extra = context.storage_dir / "untracked.h5" + _write_h5(extra) + + staged = PipelineArtifactStager(context=context).stage_declared_artifacts() + + staged_names = {path.name for path in staged} + assert "untracked.h5" not in staged_names + assert "acs_2022.h5" in staged_names + assert "policy_data.db" in staged_names + assert "calibration_weights.npy" in staged_names + + +def test_stager_creates_yearless_source_imputed_alias(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context) + + PipelineArtifactStager(context=context).stage_declared_artifacts() + + assert ( + context.artifacts_dir / "source_imputed_stratified_extended_cps_2024.h5" + ).exists() + alias = context.artifacts_dir / "source_imputed_stratified_extended_cps.h5" + assert alias.exists() + with h5py.File(alias) as h5_file: + assert list(h5_file["person"]["age"]["2024"]) == [1, 2, 3] + + +def test_stager_fails_on_missing_required_artifact(tmp_path): + context = _context(tmp_path) + + with pytest.raises(FileNotFoundError, match="acs_2022.h5"): + PipelineArtifactStager(context=context).stage_declared_artifacts() + + +def test_stager_respects_skip_flags_for_optional_ecps_paths(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts( + context, + include_enhanced_cps=False, + include_stage_5=False, + ) + + staged = PipelineArtifactStager(context=context).stage_declared_artifacts( + skip_enhanced_cps=True, + skip_stage_5=True, + ) + + staged_names = {path.name for path in staged} + assert "enhanced_cps_2024.h5" not in staged_names + assert "small_enhanced_cps_2024.h5" not in staged_names + assert "source_imputed_stratified_extended_cps_2024.h5" not in staged_names + assert "source_imputed_stratified_extended_cps.h5" not in staged_names + + +def test_dataset_inventory_contains_each_staged_artifact_once(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context, include_optional_weights=True) + stager = PipelineArtifactStager(context=context) + stager.stage_declared_artifacts() + stager.write_checkpoint_stats({"expected_outputs": 3}) + + diagnostic = DatasetInventoryWriter(context=context).write() + + inventory_path = context.artifacts_dir / "dataset_inventory.json" + payload = json.loads(inventory_path.read_text()) + logical_names = [artifact["logical_name"] for artifact in payload["artifacts"]] + assert len(logical_names) == len(set(logical_names)) + assert "policy_data_db" in logical_names + assert "data_build_checkpoint_stats" in logical_names + assert diagnostic.artifact.logical_name == "dataset_inventory" + assert diagnostic.summary["artifact_count"] == len(logical_names) + + +def test_source_dataset_schema_summary_is_metadata_only(tmp_path): + context = _context(tmp_path) + context.artifacts_dir.mkdir(parents=True) + _write_h5(context.artifacts_dir / "source_imputed_stratified_extended_cps.h5") + + diagnostic = SourceDatasetSchemaSummaryWriter(context=context).write() + + payload = json.loads( + (context.artifacts_dir / "source_dataset_schema_summary.json").read_text() + ) + assert payload["logical_name"] == "source_imputed_stratified_extended_cps" + assert payload["entities"]["person"]["variables"] == ["age", "is_disabled"] + assert payload["entities"]["household"]["row_counts"] == { + "household/weight/2024": 2 + } + assert diagnostic.summary == {"dataset_count": 3, "entity_count": 2} + + +def test_target_database_summary_reports_tables_and_row_counts(tmp_path): + context = _context(tmp_path) + context.artifacts_dir.mkdir(parents=True) + _write_sqlite(context.artifacts_dir / "policy_data.db") + + diagnostic = TargetDatabaseSchemaSummaryWriter(context=context).write() + + payload = json.loads( + (context.artifacts_dir / "target_database_schema_summary.json").read_text() + ) + assert [table["name"] for table in payload["tables"]] == ["notes", "targets"] + row_counts = {table["name"]: table["row_count"] for table in payload["tables"]} + assert row_counts == {"notes": 1, "targets": 2} + assert payload["known_target_tables"] == ["targets"] + assert diagnostic.summary["table_count"] == 2 + assert diagnostic.summary["known_target_tables"] == ("targets",) + + +def test_contract_builder_records_stage_1_diagnostics(tmp_path): + context = _context(tmp_path) + _write_required_storage_artifacts(context) + stager = PipelineArtifactStager(context=context) + stager.stage_declared_artifacts() + stager.write_checkpoint_stats({"expected_outputs": 3}) + diagnostics = write_stage_1_diagnostics(context=context) + + contract = DatasetBuildOutputContractBuilder(context=context).build( + checkpoint_stats={"expected_outputs": 3}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + duration_s=60.0, + upload_requested=True, + stage_only=True, + skip_enhanced_cps=False, + diagnostics=diagnostics, + ) + + assert {diagnostic.name for diagnostic in contract.diagnostics} == { + "dataset_inventory", + "source_dataset_schema_summary", + "target_database_schema_summary", + } + assert contract.metadata["diagnostic_count"] == 3 diff --git a/tests/unit/test_dataset_build_stage_contract.py b/tests/unit/test_dataset_build_stage_contract.py index ae36424c3..3e69cb691 100644 --- a/tests/unit/test_dataset_build_stage_contract.py +++ b/tests/unit/test_dataset_build_stage_contract.py @@ -1,6 +1,7 @@ from pathlib import Path from policyengine_us_data.stage_contracts import ( + DiagnosticRef, StageContract, contract_from_json, contract_to_json, @@ -205,3 +206,26 @@ def test_dataset_build_contract_fingerprint_excludes_run_id(tmp_path): def test_dataset_build_contract_filename_is_stable(): assert DATASET_BUILD_OUTPUT_CONTRACT_FILENAME == "dataset_build_output.json" + + +def test_dataset_build_contract_records_diagnostic_refs(tmp_path): + _write_artifacts(tmp_path) + diagnostic = DiagnosticRef( + name="dataset_inventory", + kind="dataset_inventory", + summary={"artifact_count": 13}, + ) + + contract = build_dataset_build_output_contract( + artifacts_dir=tmp_path, + run_id="run-a", + code_sha="abc123", + package_version="1.98.2", + checkpoint_stats={"expected_outputs": 4}, + started_at="2026-05-08T12:00:00Z", + completed_at="2026-05-08T12:01:00Z", + diagnostics=(diagnostic,), + ) + + assert contract.diagnostics == (diagnostic,) + assert contract.metadata["diagnostic_count"] == 1 diff --git a/tests/unit/test_modal_data_build.py b/tests/unit/test_modal_data_build.py index eed66aec7..28abfbb99 100644 --- a/tests/unit/test_modal_data_build.py +++ b/tests/unit/test_modal_data_build.py @@ -309,6 +309,7 @@ def test_write_dataset_build_contract_writes_stage_1_handoff(tmp_path): upload_requested=False, stage_only=True, skip_enhanced_cps=True, + branch="stage-1", ) contract_path = tmp_path / "dataset_build_output.json" diff --git a/uv.lock b/uv.lock index 67329f64c..81d80b028 100644 --- a/uv.lock +++ b/uv.lock @@ -2122,7 +2122,7 @@ wheels = [ [[package]] name = "policyengine-us" -version = "1.702.0" +version = "1.702.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "microdf-python" }, @@ -2132,9 +2132,9 @@ dependencies = [ { name = "tables" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/7e/d3095e6dde387cb56eb2dd0543cdc0b0f7670446d3b6ea45468165d60d1f/policyengine_us-1.702.0.tar.gz", hash = "sha256:689526d444c98681d517247d5308e795e02f24c65423295232ab347e61cac981", size = 9876039, upload-time = "2026-05-21T14:56:36.133Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/3e/6000ddb6cd51bb5d832089cf2904b88773e431e358fef4f4bd736d5aea0e/policyengine_us-1.702.1.tar.gz", hash = "sha256:b3782233a8e3d6c5eca48f329cad87e46319c170eacb64836f1966e58e5f95b6", size = 9884003, upload-time = "2026-05-21T16:42:13.543Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/1d/67cde50bf6401c5c3ab95ff8f4036876422fa6fc72481425f3f3c7eb3177/policyengine_us-1.702.0-py3-none-any.whl", hash = "sha256:83d787337760587dbfcfe6bc2ae59afb53d2baa5827cb535776ff7147561a72f", size = 10649615, upload-time = "2026-05-21T14:56:33.349Z" }, + { url = "https://files.pythonhosted.org/packages/79/d7/95bbe3549a5f932ff91a53f84a78a70497b2e11e395dfd1fdd1d76ba9a71/policyengine_us-1.702.1-py3-none-any.whl", hash = "sha256:f6029ae7319219f1e36c805f778dd8594742e89867b0c5fc07459b6ee18b487e", size = 10673068, upload-time = "2026-05-21T16:42:09.985Z" }, ] [[package]] @@ -2204,7 +2204,7 @@ requires-dist = [ { name = "pandas", specifier = ">=2.3.1" }, { name = "pip-system-certs", specifier = ">=3.0" }, { name = "policyengine-core", specifier = ">=3.26.1,<3.27" }, - { name = "policyengine-us", specifier = "==1.702.0" }, + { name = "policyengine-us", specifier = "==1.702.1" }, { name = "requests", specifier = ">=2.25.0" }, { name = "samplics", marker = "extra == 'calibration'" }, { name = "scipy", specifier = ">=1.15.3" },