From 037623f1ef87f984a0d5967b536db5f26db82c22 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 19 May 2026 21:45:14 +0200 Subject: [PATCH 1/4] Add scoped Stage 3 fit result bundles --- changelog.d/1045.changed | 1 + docs/engineering/stages/fit_weights.md | 9 + modal_app/pipeline.py | 83 ++++----- modal_app/remote_calibration_runner.py | 15 +- policyengine_us_data/fit_weights/__init__.py | 12 ++ policyengine_us_data/fit_weights/bundles.py | 182 +++++++++++++++++++ tests/unit/fit_weights/test_bundles.py | 130 +++++++++++++ tests/unit/fit_weights/test_pipeline_docs.py | 4 + tests/unit/test_pipeline_source_contracts.py | 23 ++- 9 files changed, 404 insertions(+), 55 deletions(-) create mode 100644 changelog.d/1045.changed create mode 100644 policyengine_us_data/fit_weights/bundles.py create mode 100644 tests/unit/fit_weights/test_bundles.py diff --git a/changelog.d/1045.changed b/changelog.d/1045.changed new file mode 100644 index 000000000..2d3014a32 --- /dev/null +++ b/changelog.d/1045.changed @@ -0,0 +1 @@ +Added scoped Stage 3 fitted-weight input and output bundles. diff --git a/docs/engineering/stages/fit_weights.md b/docs/engineering/stages/fit_weights.md index 795e74160..21424b455 100644 --- a/docs/engineering/stages/fit_weights.md +++ b/docs/engineering/stages/fit_weights.md @@ -8,6 +8,9 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights` step manifests for reuse decisions. - `ScopedFitArtifacts` defines the artifact filenames written by the Modal fit step and consumed by downstream H5 builders. +- `FittedWeightsInputBundle`, `FitResultBytes`, and + `FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result + bytes typed before they become files. The current artifact names remain behavior-compatible: @@ -21,3 +24,9 @@ The current artifact names remain behavior-compatible: When changing Stage 3 fitting parameters, artifact names, or scope behavior, update the central specs first and then adapt Modal callers to consume those specs. Do not add parallel filename constants in orchestration code. + +When changing remote result handling, keep `_collect_outputs(...)` as the +compatibility adapter for subprocess stdout markers and convert its dictionary +shape into `FittedWeightsOutputBundle` before writing artifacts to the pipeline +volume. Fit step manifests should attach diagnostics from the matching output +scope rather than all run diagnostics. diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index ab1bf9d3f..acd7bb26b 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -39,7 +39,6 @@ import time import traceback from datetime import datetime, timezone -from io import BytesIO from pathlib import Path import modal @@ -110,6 +109,8 @@ from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.fit_weights import ( # noqa: E402 FitScope, + FittedWeightsInputBundle, + FittedWeightsOutputBundle, NATIONAL_FIT_LAMBDA_L0 as _NATIONAL_FIT_LAMBDA_L0, fit_artifacts_for_scope, fitted_weights_spec_for_scope, @@ -279,13 +280,14 @@ def archive_diagnostics( vol: modal.Volume, prefix: str = "", scope: FitScope | str | None = None, -) -> None: +) -> list[ArtifactReference]: """Archive calibration diagnostics to the run directory.""" diag_dir = Path(RUNS_DIR) / run_id / "diagnostics" diag_dir.mkdir(parents=True, exist_ok=True) scope = scope or (FitScope.NATIONAL if prefix == "national_" else FitScope.REGIONAL) file_map = fit_artifacts_for_scope(scope).diagnostic_result_filenames() + written_paths: list[Path] = [] for key, filename in file_map.items(): data = result_bytes.get(key) @@ -293,9 +295,15 @@ def archive_diagnostics( path = diag_dir / filename with open(path, "wb") as f: f.write(data) + written_paths.append(path) print(f" Archived {filename} ({len(data):,} bytes)") vol.commit() + return collect_artifacts( + written_paths, + role="diagnostic", + missing_ok=True, + ) # ── Include other Modal apps ───────────────────────────────────── @@ -1315,12 +1323,11 @@ def run_pipeline( print(f" Completed in {completed_package_manifest.duration_s}s") # ── Step 3: Fit weights (parallel) ── - fit_inputs = _artifact_identities( - { - "calibration_package": _artifacts_dir(run_id) - / "calibration_package.pkl", - } + regional_fit_input = FittedWeightsInputBundle( + scope=FitScope.REGIONAL, + calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl", ) + fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths()) regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL) national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL) regional_fit_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL) @@ -1425,34 +1432,26 @@ def run_pipeline( # Collect regional results print(" Waiting for regional fit...") regional_result = regional_handle.get() + regional_output = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes=regional_result, + run_id=run_id, + ) print(" Regional fit complete. Writing to volume...") # Write regional results to pipeline volume (run-scoped) artifacts_rel = f"artifacts/{run_id}" if run_id else "artifacts" with pipeline_volume.batch_upload(force=True) as batch: - batch.put_file( - BytesIO(regional_result["weights"]), - f"{artifacts_rel}/{regional_fit_artifacts.weights.filename}", - ) - if regional_result.get("geography"): - batch.put_file( - BytesIO(regional_result["geography"]), - f"{artifacts_rel}/{regional_fit_artifacts.geography.filename}", - ) - if regional_result.get("config"): - batch.put_file( - BytesIO(regional_result["config"]), - f"{artifacts_rel}/{regional_fit_artifacts.run_config.filename}", - ) + regional_output.write_artifacts(batch, artifacts_rel) - archive_diagnostics( + regional_diagnostics = archive_diagnostics( run_id, - regional_result, + regional_output.diagnostic_result_bytes(), pipeline_volume, - scope=FitScope.REGIONAL, + scope=regional_output.scope, ) regional_outputs = collect_artifacts( - regional_fit_artifacts.artifact_paths(_artifacts_dir(run_id)), + regional_output.artifact_paths(_artifacts_dir(run_id)), missing_ok=True, ) regional_fit_reuse_measurement = ReuseMeasurement( @@ -1462,7 +1461,7 @@ def run_pipeline( _complete_step_manifest( regional_fit_manifest, outputs=regional_outputs, - diagnostics=_collect_diagnostics(run_id), + diagnostics=regional_diagnostics, reuse_decision="computed", reuse_measurement=regional_fit_reuse_measurement, vol=pipeline_volume, @@ -1473,38 +1472,30 @@ def run_pipeline( if national_handle is not None: print(" Waiting for national fit...") national_result = national_handle.get() + national_output = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.NATIONAL, + result_bytes=national_result, + run_id=run_id, + ) print(" National fit complete. Writing to volume...") with pipeline_volume.batch_upload(force=True) as batch: - batch.put_file( - BytesIO(national_result["weights"]), - f"{artifacts_rel}/{national_fit_artifacts.weights.filename}", - ) - if national_result.get("geography"): - batch.put_file( - BytesIO(national_result["geography"]), - f"{artifacts_rel}/{national_fit_artifacts.geography.filename}", - ) - if national_result.get("config"): - batch.put_file( - BytesIO(national_result["config"]), - f"{artifacts_rel}/{national_fit_artifacts.run_config.filename}", - ) - - archive_diagnostics( + national_output.write_artifacts(batch, artifacts_rel) + + national_diagnostics = archive_diagnostics( run_id, - national_result, + national_output.diagnostic_result_bytes(), pipeline_volume, - scope=FitScope.NATIONAL, + scope=national_output.scope, ) national_outputs = collect_artifacts( - national_fit_artifacts.artifact_paths(_artifacts_dir(run_id)), + national_output.artifact_paths(_artifacts_dir(run_id)), missing_ok=True, ) _complete_step_manifest( national_fit_manifest, outputs=national_outputs, - diagnostics=_collect_diagnostics(run_id), + diagnostics=national_diagnostics, reuse_decision="computed", reuse_measurement=ReuseMeasurement( expected_outputs=len(national_outputs), diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index ccd96a658..28289be4f 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -13,6 +13,7 @@ from modal_app.images import gpu_image as image # noqa: E402 from policyengine_us_data.fit_weights import ( # noqa: E402 + FitResultBytes, FitScope, NATIONAL_FIT_LAMBDA_L0, fit_artifacts_for_scope, @@ -139,13 +140,13 @@ def _collect_outputs(cal_lines): with open(config_path, "rb") as f: config_bytes = f.read() - return { - "weights": weights_bytes, - "geography": geography_bytes, - "log": log_bytes, - "cal_log": cal_log_bytes, - "config": config_bytes, - } + return FitResultBytes( + weights=weights_bytes, + geography=geography_bytes, + diagnostics=log_bytes, + epoch_log=cal_log_bytes, + run_config=config_bytes, + ).to_result_dict() def _fit_output_filenames( diff --git a/policyengine_us_data/fit_weights/__init__.py b/policyengine_us_data/fit_weights/__init__.py index 2bd50c5db..fb135a374 100644 --- a/policyengine_us_data/fit_weights/__init__.py +++ b/policyengine_us_data/fit_weights/__init__.py @@ -7,6 +7,13 @@ ScopedFitArtifacts, fit_artifacts_for_scope, ) +from policyengine_us_data.fit_weights.bundles import ( + FitResultBytes, + FitWeightsBuildContext, + FittedWeightsInputBundle, + FittedWeightsOutputBundle, + MissingFitWeightsOutputError, +) from policyengine_us_data.fit_weights.specs import ( FIT_BETA, FIT_LOG_FREQ, @@ -35,8 +42,13 @@ "FitArtifactRole", "FitArtifactSpec", "FitHyperparameters", + "FitResultBytes", "FitScope", + "FitWeightsBuildContext", + "FittedWeightsInputBundle", + "FittedWeightsOutputBundle", "FittedWeightsSpec", + "MissingFitWeightsOutputError", "ScopedFitArtifacts", "fit_artifacts_for_scope", "fitted_weights_spec_for_scope", diff --git a/policyengine_us_data/fit_weights/bundles.py b/policyengine_us_data/fit_weights/bundles.py new file mode 100644 index 000000000..5aead4868 --- /dev/null +++ b/policyengine_us_data/fit_weights/bundles.py @@ -0,0 +1,182 @@ +"""Scoped Stage 3 fitted-weight input and output bundles.""" + +from __future__ import annotations + +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Mapping + +from policyengine_us_data.fit_weights.artifacts import ( + ScopedFitArtifacts, + fit_artifacts_for_scope, +) +from policyengine_us_data.fit_weights.specs import FitScope +from policyengine_us_data.pipeline_metadata import pipeline_node +from policyengine_us_data.pipeline_schema import PipelineNode + + +class MissingFitWeightsOutputError(ValueError): + """Raised when remote fit bytes omit required weights.""" + + +@dataclass(frozen=True) +class FitWeightsBuildContext: + """Run-scoped filesystem context for Stage 3 fitted-weight artifacts.""" + + run_id: str + artifacts_root: Path + diagnostics_root: Path + + +@dataclass(frozen=True) +class FittedWeightsInputBundle: + """Scoped Stage 3 input paths consumed before fitting starts.""" + + scope: FitScope | str + calibration_package_path: Path + + def __post_init__(self) -> None: + object.__setattr__(self, "scope", FitScope.parse(self.scope)) + object.__setattr__( + self, + "calibration_package_path", + Path(self.calibration_package_path), + ) + + def artifact_identity_paths(self) -> dict[str, Path]: + """Return paths used for Stage 3 input identity calculation.""" + + return {"calibration_package": self.calibration_package_path} + + +@dataclass(frozen=True) +class FitResultBytes: + """Compatibility transport model for current remote fit result bytes.""" + + weights: bytes + geography: bytes | None = None + diagnostics: bytes | None = None + epoch_log: bytes | None = None + run_config: bytes | None = None + + @classmethod + def from_mapping(cls, result_bytes: Mapping[str, bytes | None]) -> "FitResultBytes": + """Build transport bytes from the current remote result dictionary.""" + + weights = result_bytes.get("weights") + if weights is None: + raise MissingFitWeightsOutputError( + "Fitted-weight result is missing required weights bytes." + ) + return cls( + weights=weights, + geography=result_bytes.get("geography"), + diagnostics=result_bytes.get("log"), + epoch_log=result_bytes.get("cal_log"), + run_config=result_bytes.get("config"), + ) + + def to_result_dict(self) -> dict[str, bytes | None]: + """Return the legacy result dictionary shape used by Modal adapters.""" + + return { + "weights": self.weights, + "geography": self.geography, + "log": self.diagnostics, + "cal_log": self.epoch_log, + "config": self.run_config, + } + + def bytes_for_result_key(self, result_key: str | None) -> bytes | None: + """Return bytes for an artifact spec result key.""" + + return self.to_result_dict().get(result_key or "") + + +@pipeline_node( + PipelineNode( + id="fitted_weights_output_bundle", + label="Fitted Weights Output Bundle", + node_type="library", + description="Scoped Stage 3 result bytes before artifact file writes.", + source_file="policyengine_us_data/fit_weights/bundles.py", + status="current", + stability="moving", + pathways=["fit_weights", "artifact_identity"], + artifacts_in=["remote fit result bytes"], + artifacts_out=["scoped fitted-weight artifact writes"], + validation_commands=["uv run pytest tests/unit/fit_weights/test_bundles.py"], + ) +) +@dataclass(frozen=True) +class FittedWeightsOutputBundle: + """Scoped output bundle created before Stage 3 bytes become files.""" + + scope: FitScope | str + result: FitResultBytes + artifacts: ScopedFitArtifacts + run_id: str = "" + + def __post_init__(self) -> None: + scope = FitScope.parse(self.scope) + object.__setattr__(self, "scope", scope) + if self.artifacts.scope != scope: + raise ValueError( + "Output bundle scope does not match artifact catalog: " + f"{scope.value} != {self.artifacts.scope.value}" + ) + + @classmethod + def from_result_bytes( + cls, + *, + scope: FitScope | str, + result_bytes: Mapping[str, bytes | None], + run_id: str = "", + ) -> "FittedWeightsOutputBundle": + """Build a scoped bundle from the current remote result dictionary.""" + + scope = FitScope.parse(scope) + return cls( + scope=scope, + result=FitResultBytes.from_mapping(result_bytes), + artifacts=fit_artifacts_for_scope(scope), + run_id=run_id, + ) + + def write_artifacts(self, batch, artifacts_rel: str) -> list[str]: + """Write present primary artifacts to a Modal batch upload object.""" + + written: list[str] = [] + for artifact in self.artifacts.artifact_specs(): + data = self.result.bytes_for_result_key(artifact.result_key) + if data is None: + continue + destination = f"{artifacts_rel}/{artifact.filename}" + batch.put_file(BytesIO(data), destination) + written.append(destination) + return written + + def artifact_paths(self, artifacts_root: str | Path) -> list[Path]: + """Return expected primary artifact paths under a pipeline artifact root.""" + + return self.artifacts.artifact_paths(artifacts_root) + + def diagnostic_result_bytes(self) -> dict[str, bytes | None]: + """Return only diagnostics belonging to this output scope.""" + + return { + artifact.result_key: self.result.bytes_for_result_key(artifact.result_key) + for artifact in self.artifacts.diagnostic_specs() + if artifact.result_key is not None + } + + +__all__ = [ + "FitResultBytes", + "FitWeightsBuildContext", + "FittedWeightsInputBundle", + "FittedWeightsOutputBundle", + "MissingFitWeightsOutputError", +] diff --git a/tests/unit/fit_weights/test_bundles.py b/tests/unit/fit_weights/test_bundles.py new file mode 100644 index 000000000..cf6289d57 --- /dev/null +++ b/tests/unit/fit_weights/test_bundles.py @@ -0,0 +1,130 @@ +from pathlib import Path + +import pytest + +from policyengine_us_data.fit_weights import ( + FitScope, + FittedWeightsInputBundle, + FittedWeightsOutputBundle, + MissingFitWeightsOutputError, +) + + +class FakeBatch: + def __init__(self) -> None: + self.files: dict[str, bytes] = {} + + def put_file(self, file_obj, destination: str) -> None: + self.files[destination] = file_obj.read() + + +def test_input_bundle_exposes_calibration_package_identity_path() -> None: + bundle = FittedWeightsInputBundle( + scope="regional", + calibration_package_path=Path( + "/pipeline/artifacts/run/calibration_package.pkl" + ), + ) + + assert bundle.scope == FitScope.REGIONAL + assert bundle.artifact_identity_paths() == { + "calibration_package": Path("/pipeline/artifacts/run/calibration_package.pkl") + } + + +def test_regional_output_bundle_writes_expected_paths() -> None: + bundle = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes={ + "weights": b"weights", + "geography": b"geo", + "config": b"config", + "log": b"log", + "cal_log": b"epoch", + }, + run_id="run-1", + ) + batch = FakeBatch() + + written = bundle.write_artifacts(batch, "artifacts/run-1") + + assert written == [ + "artifacts/run-1/calibration_weights.npy", + "artifacts/run-1/geography_assignment.npz", + "artifacts/run-1/unified_run_config.json", + ] + assert batch.files["artifacts/run-1/calibration_weights.npy"] == b"weights" + assert bundle.artifact_paths("/pipeline/artifacts/run-1") == [ + Path("/pipeline/artifacts/run-1/calibration_weights.npy"), + Path("/pipeline/artifacts/run-1/geography_assignment.npz"), + Path("/pipeline/artifacts/run-1/unified_run_config.json"), + ] + + +def test_national_output_bundle_writes_expected_paths() -> None: + bundle = FittedWeightsOutputBundle.from_result_bytes( + scope="national", + result_bytes={ + "weights": b"weights", + "geography": b"geo", + "config": b"config", + }, + ) + batch = FakeBatch() + + written = bundle.write_artifacts(batch, "artifacts/run-1") + + assert written == [ + "artifacts/run-1/national_calibration_weights.npy", + "artifacts/run-1/national_geography_assignment.npz", + "artifacts/run-1/national_unified_run_config.json", + ] + assert batch.files["artifacts/run-1/national_calibration_weights.npy"] == b"weights" + + +def test_missing_optional_epoch_log_is_allowed() -> None: + bundle = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes={ + "weights": b"weights", + "log": b"log", + }, + ) + + assert bundle.diagnostic_result_bytes() == { + "log": b"log", + "cal_log": None, + "config": None, + } + + +def test_missing_weights_is_a_hard_failure() -> None: + with pytest.raises(MissingFitWeightsOutputError, match="weights"): + FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes={"geography": b"geo"}, + ) + + +def test_diagnostics_are_scoped_to_the_output_bundle() -> None: + regional = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes={ + "weights": b"weights", + "log": b"regional-log", + "cal_log": b"regional-epoch", + }, + ) + national = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.NATIONAL, + result_bytes={ + "weights": b"weights", + "log": b"national-log", + "cal_log": b"national-epoch", + }, + ) + + assert regional.artifacts.diagnostics.filename == "unified_diagnostics.csv" + assert national.artifacts.diagnostics.filename == "national_unified_diagnostics.csv" + assert regional.diagnostic_result_bytes()["log"] == b"regional-log" + assert national.diagnostic_result_bytes()["log"] == b"national-log" diff --git a/tests/unit/fit_weights/test_pipeline_docs.py b/tests/unit/fit_weights/test_pipeline_docs.py index 3b3a26266..ec1b423da 100644 --- a/tests/unit/fit_weights/test_pipeline_docs.py +++ b/tests/unit/fit_weights/test_pipeline_docs.py @@ -18,8 +18,12 @@ def test_fit_weights_identity_nodes_are_in_generated_pipeline_docs() -> None: assert "fitted_weights_spec" in decorated assert "fitted_weights_artifacts" in decorated + assert "fitted_weights_output_bundle" in decorated assert "fit_weights" in decorated["fitted_weights_spec"].metadata["pathways"] assert "fit_weights" in decorated["fitted_weights_artifacts"].metadata["pathways"] + assert ( + "fit_weights" in decorated["fitted_weights_output_bundle"].metadata["pathways"] + ) def test_stage_3_pipeline_map_labels_match_scoped_artifacts() -> None: diff --git a/tests/unit/test_pipeline_source_contracts.py b/tests/unit/test_pipeline_source_contracts.py index 95f426890..189ff9fd4 100644 --- a/tests/unit/test_pipeline_source_contracts.py +++ b/tests/unit/test_pipeline_source_contracts.py @@ -125,11 +125,30 @@ def test_run_pipeline_uses_stage_3_fit_specs_for_reuse_and_paths() -> None: assert "national_fit_spec.manifest_parameters(" in source assert "regional_fit_spec.runtime_kwargs()" in source assert "national_fit_spec.runtime_kwargs()" in source - assert "regional_fit_artifacts.artifact_paths(_artifacts_dir(run_id))" in source - assert "national_fit_artifacts.artifact_paths(_artifacts_dir(run_id))" in source + assert "regional_output.artifact_paths(_artifacts_dir(run_id))" in source + assert "national_output.artifact_paths(_artifacts_dir(run_id))" in source assert "diagnostic_result_filenames()" in archive_source +def test_run_pipeline_converts_fit_results_to_scoped_output_bundles() -> None: + source_text = PIPELINE_SOURCE.read_text() + tree = ast.parse(source_text) + run_pipeline = _function_def(tree, "run_pipeline") + archive_diagnostics = _function_def(tree, "archive_diagnostics") + source = ast.get_source_segment(source_text, run_pipeline) + archive_source = ast.get_source_segment(source_text, archive_diagnostics) + + assert "FittedWeightsInputBundle(" in source + assert "FittedWeightsOutputBundle.from_result_bytes(" in source + assert "regional_output.write_artifacts(batch, artifacts_rel)" in source + assert "national_output.write_artifacts(batch, artifacts_rel)" in source + assert "regional_output.diagnostic_result_bytes()" in source + assert "national_output.diagnostic_result_bytes()" in source + assert "diagnostics=regional_diagnostics" in source + assert "diagnostics=national_diagnostics" in source + assert 'role="diagnostic"' in archive_source + + def test_local_area_consumes_centralized_stage_3_artifact_specs() -> None: source = Path("modal_app/local_area.py").read_text() From 78f04a7e7d088b100b402c0ea216407cfd27baab Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 20 May 2026 18:20:36 +0200 Subject: [PATCH 2/4] Require scoped fit output artifacts --- policyengine_us_data/fit_weights/bundles.py | 8 ++++- tests/unit/fit_weights/test_bundles.py | 34 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/policyengine_us_data/fit_weights/bundles.py b/policyengine_us_data/fit_weights/bundles.py index 5aead4868..46a2df7e9 100644 --- a/policyengine_us_data/fit_weights/bundles.py +++ b/policyengine_us_data/fit_weights/bundles.py @@ -17,7 +17,7 @@ class MissingFitWeightsOutputError(ValueError): - """Raised when remote fit bytes omit required weights.""" + """Raised when remote fit bytes omit required fitted-weight artifacts.""" @dataclass(frozen=True) @@ -152,6 +152,12 @@ def write_artifacts(self, batch, artifacts_rel: str) -> list[str]: for artifact in self.artifacts.artifact_specs(): data = self.result.bytes_for_result_key(artifact.result_key) if data is None: + if artifact.required: + raise MissingFitWeightsOutputError( + "Fitted-weight result is missing required " + f"{self.scope.value} {artifact.role.value} bytes " + f"for {artifact.filename}." + ) continue destination = f"{artifacts_rel}/{artifact.filename}" batch.put_file(BytesIO(data), destination) diff --git a/tests/unit/fit_weights/test_bundles.py b/tests/unit/fit_weights/test_bundles.py index cf6289d57..e008e9834 100644 --- a/tests/unit/fit_weights/test_bundles.py +++ b/tests/unit/fit_weights/test_bundles.py @@ -87,6 +87,8 @@ def test_missing_optional_epoch_log_is_allowed() -> None: scope=FitScope.REGIONAL, result_bytes={ "weights": b"weights", + "geography": b"geo", + "config": b"config", "log": b"log", }, ) @@ -94,7 +96,7 @@ def test_missing_optional_epoch_log_is_allowed() -> None: assert bundle.diagnostic_result_bytes() == { "log": b"log", "cal_log": None, - "config": None, + "config": b"config", } @@ -106,11 +108,39 @@ def test_missing_weights_is_a_hard_failure() -> None: ) +@pytest.mark.parametrize( + ("missing_key", "expected_role"), + [ + ("geography", "geography"), + ("config", "run_config"), + ], +) +def test_missing_required_primary_artifacts_fail_before_writes( + missing_key: str, + expected_role: str, +) -> None: + result_bytes = { + "weights": b"weights", + "geography": b"geo", + "config": b"config", + } + result_bytes.pop(missing_key) + bundle = FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes=result_bytes, + ) + + with pytest.raises(MissingFitWeightsOutputError, match=expected_role): + bundle.write_artifacts(FakeBatch(), "artifacts/run-1") + + def test_diagnostics_are_scoped_to_the_output_bundle() -> None: regional = FittedWeightsOutputBundle.from_result_bytes( scope=FitScope.REGIONAL, result_bytes={ "weights": b"weights", + "geography": b"regional-geo", + "config": b"regional-config", "log": b"regional-log", "cal_log": b"regional-epoch", }, @@ -119,6 +149,8 @@ def test_diagnostics_are_scoped_to_the_output_bundle() -> None: scope=FitScope.NATIONAL, result_bytes={ "weights": b"weights", + "geography": b"national-geo", + "config": b"national-config", "log": b"national-log", "cal_log": b"national-epoch", }, From 0d8fca541b0179ff010a80be4df15f69e0af23c3 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 20 May 2026 22:03:22 +0200 Subject: [PATCH 3/4] Move fit weight test setup into fixtures --- tests/unit/fit_weights/conftest.py | 84 ++++++++++++ tests/unit/fit_weights/test_bundles.py | 130 +++++++------------ tests/unit/fit_weights/test_pipeline_docs.py | 17 +-- 3 files changed, 135 insertions(+), 96 deletions(-) create mode 100644 tests/unit/fit_weights/conftest.py diff --git a/tests/unit/fit_weights/conftest.py b/tests/unit/fit_weights/conftest.py new file mode 100644 index 000000000..ca9565d13 --- /dev/null +++ b/tests/unit/fit_weights/conftest.py @@ -0,0 +1,84 @@ +from collections.abc import Callable +from pathlib import Path + +import pytest +import yaml + +from policyengine_us_data.fit_weights import ( + FitScope, + FittedWeightsOutputBundle, +) + + +class FakeBatch: + def __init__(self) -> None: + self.files: dict[str, bytes] = {} + + def put_file(self, file_obj, destination: str) -> None: + self.files[destination] = file_obj.read() + + +@pytest.fixture +def artifacts_rel() -> str: + return "artifacts/run-1" + + +@pytest.fixture +def calibration_package_path() -> Path: + return Path("/pipeline/artifacts/run/calibration_package.pkl") + + +@pytest.fixture +def fake_batch() -> FakeBatch: + return FakeBatch() + + +@pytest.fixture +def regional_result_bytes() -> dict[str, bytes]: + return { + "weights": b"weights", + "geography": b"regional-geo", + "config": b"regional-config", + "log": b"regional-log", + "cal_log": b"regional-epoch", + } + + +@pytest.fixture +def national_result_bytes() -> dict[str, bytes]: + return { + "weights": b"weights", + "geography": b"national-geo", + "config": b"national-config", + "log": b"national-log", + "cal_log": b"national-epoch", + } + + +@pytest.fixture +def regional_output_bundle( + regional_result_bytes: dict[str, bytes], +) -> FittedWeightsOutputBundle: + return FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.REGIONAL, + result_bytes=regional_result_bytes, + run_id="run-1", + ) + + +@pytest.fixture +def national_output_bundle( + national_result_bytes: dict[str, bytes], +) -> FittedWeightsOutputBundle: + return FittedWeightsOutputBundle.from_result_bytes( + scope=FitScope.NATIONAL, + result_bytes=national_result_bytes, + run_id="run-1", + ) + + +@pytest.fixture +def stage_3_substage() -> Callable[[str], dict]: + data = yaml.safe_load(Path("docs/pipeline_map.yaml").read_text()) + substages = {substage["id"]: substage for substage in data["stages"]} + return substages.__getitem__ diff --git a/tests/unit/fit_weights/test_bundles.py b/tests/unit/fit_weights/test_bundles.py index e008e9834..c5985f96d 100644 --- a/tests/unit/fit_weights/test_bundles.py +++ b/tests/unit/fit_weights/test_bundles.py @@ -10,93 +10,72 @@ ) -class FakeBatch: - def __init__(self) -> None: - self.files: dict[str, bytes] = {} - - def put_file(self, file_obj, destination: str) -> None: - self.files[destination] = file_obj.read() - - -def test_input_bundle_exposes_calibration_package_identity_path() -> None: +def test_input_bundle_exposes_calibration_package_identity_path( + calibration_package_path: Path, +) -> None: bundle = FittedWeightsInputBundle( scope="regional", - calibration_package_path=Path( - "/pipeline/artifacts/run/calibration_package.pkl" - ), + calibration_package_path=calibration_package_path, ) assert bundle.scope == FitScope.REGIONAL assert bundle.artifact_identity_paths() == { - "calibration_package": Path("/pipeline/artifacts/run/calibration_package.pkl") + "calibration_package": calibration_package_path } -def test_regional_output_bundle_writes_expected_paths() -> None: - bundle = FittedWeightsOutputBundle.from_result_bytes( - scope=FitScope.REGIONAL, - result_bytes={ - "weights": b"weights", - "geography": b"geo", - "config": b"config", - "log": b"log", - "cal_log": b"epoch", - }, - run_id="run-1", - ) - batch = FakeBatch() - - written = bundle.write_artifacts(batch, "artifacts/run-1") +def test_regional_output_bundle_writes_expected_paths( + artifacts_rel: str, + fake_batch, + regional_output_bundle: FittedWeightsOutputBundle, +) -> None: + written = regional_output_bundle.write_artifacts(fake_batch, artifacts_rel) assert written == [ "artifacts/run-1/calibration_weights.npy", "artifacts/run-1/geography_assignment.npz", "artifacts/run-1/unified_run_config.json", ] - assert batch.files["artifacts/run-1/calibration_weights.npy"] == b"weights" - assert bundle.artifact_paths("/pipeline/artifacts/run-1") == [ + assert fake_batch.files["artifacts/run-1/calibration_weights.npy"] == b"weights" + assert regional_output_bundle.artifact_paths("/pipeline/artifacts/run-1") == [ Path("/pipeline/artifacts/run-1/calibration_weights.npy"), Path("/pipeline/artifacts/run-1/geography_assignment.npz"), Path("/pipeline/artifacts/run-1/unified_run_config.json"), ] -def test_national_output_bundle_writes_expected_paths() -> None: - bundle = FittedWeightsOutputBundle.from_result_bytes( - scope="national", - result_bytes={ - "weights": b"weights", - "geography": b"geo", - "config": b"config", - }, - ) - batch = FakeBatch() - - written = bundle.write_artifacts(batch, "artifacts/run-1") +def test_national_output_bundle_writes_expected_paths( + artifacts_rel: str, + fake_batch, + national_output_bundle: FittedWeightsOutputBundle, +) -> None: + written = national_output_bundle.write_artifacts(fake_batch, artifacts_rel) assert written == [ "artifacts/run-1/national_calibration_weights.npy", "artifacts/run-1/national_geography_assignment.npz", "artifacts/run-1/national_unified_run_config.json", ] - assert batch.files["artifacts/run-1/national_calibration_weights.npy"] == b"weights" + assert ( + fake_batch.files["artifacts/run-1/national_calibration_weights.npy"] + == b"weights" + ) -def test_missing_optional_epoch_log_is_allowed() -> None: +def test_missing_optional_epoch_log_is_allowed( + regional_result_bytes: dict[str, bytes], +) -> None: + result_bytes = dict(regional_result_bytes) + result_bytes.pop("cal_log") bundle = FittedWeightsOutputBundle.from_result_bytes( scope=FitScope.REGIONAL, - result_bytes={ - "weights": b"weights", - "geography": b"geo", - "config": b"config", - "log": b"log", - }, + result_bytes=result_bytes, ) assert bundle.diagnostic_result_bytes() == { - "log": b"log", + "log": b"regional-log", "cal_log": None, - "config": b"config", + "config": b"regional-config", } @@ -118,12 +97,11 @@ def test_missing_weights_is_a_hard_failure() -> None: def test_missing_required_primary_artifacts_fail_before_writes( missing_key: str, expected_role: str, + artifacts_rel: str, + fake_batch, + regional_result_bytes: dict[str, bytes], ) -> None: - result_bytes = { - "weights": b"weights", - "geography": b"geo", - "config": b"config", - } + result_bytes = dict(regional_result_bytes) result_bytes.pop(missing_key) bundle = FittedWeightsOutputBundle.from_result_bytes( scope=FitScope.REGIONAL, @@ -131,32 +109,20 @@ def test_missing_required_primary_artifacts_fail_before_writes( ) with pytest.raises(MissingFitWeightsOutputError, match=expected_role): - bundle.write_artifacts(FakeBatch(), "artifacts/run-1") + bundle.write_artifacts(fake_batch, artifacts_rel) -def test_diagnostics_are_scoped_to_the_output_bundle() -> None: - regional = FittedWeightsOutputBundle.from_result_bytes( - scope=FitScope.REGIONAL, - result_bytes={ - "weights": b"weights", - "geography": b"regional-geo", - "config": b"regional-config", - "log": b"regional-log", - "cal_log": b"regional-epoch", - }, +def test_diagnostics_are_scoped_to_the_output_bundle( + regional_output_bundle: FittedWeightsOutputBundle, + national_output_bundle: FittedWeightsOutputBundle, +) -> None: + assert ( + regional_output_bundle.artifacts.diagnostics.filename + == "unified_diagnostics.csv" ) - national = FittedWeightsOutputBundle.from_result_bytes( - scope=FitScope.NATIONAL, - result_bytes={ - "weights": b"weights", - "geography": b"national-geo", - "config": b"national-config", - "log": b"national-log", - "cal_log": b"national-epoch", - }, + assert ( + national_output_bundle.artifacts.diagnostics.filename + == "national_unified_diagnostics.csv" ) - - assert regional.artifacts.diagnostics.filename == "unified_diagnostics.csv" - assert national.artifacts.diagnostics.filename == "national_unified_diagnostics.csv" - assert regional.diagnostic_result_bytes()["log"] == b"regional-log" - assert national.diagnostic_result_bytes()["log"] == b"national-log" + assert regional_output_bundle.diagnostic_result_bytes()["log"] == b"regional-log" + assert national_output_bundle.diagnostic_result_bytes()["log"] == b"national-log" diff --git a/tests/unit/fit_weights/test_pipeline_docs.py b/tests/unit/fit_weights/test_pipeline_docs.py index ec1b423da..346566814 100644 --- a/tests/unit/fit_weights/test_pipeline_docs.py +++ b/tests/unit/fit_weights/test_pipeline_docs.py @@ -1,18 +1,7 @@ -from pathlib import Path - -import yaml - from policyengine_us_data.fit_weights import FitScope, fit_artifacts_for_scope from scripts.extract_pipeline_docs import scan_decorated_objects -def _substage(substage_id: str) -> dict: - data = yaml.safe_load(Path("docs/pipeline_map.yaml").read_text()) - return next( - substage for substage in data["stages"] if substage["id"] == substage_id - ) - - def test_fit_weights_identity_nodes_are_in_generated_pipeline_docs() -> None: decorated = scan_decorated_objects() @@ -26,11 +15,11 @@ def test_fit_weights_identity_nodes_are_in_generated_pipeline_docs() -> None: ) -def test_stage_3_pipeline_map_labels_match_scoped_artifacts() -> None: +def test_stage_3_pipeline_map_labels_match_scoped_artifacts(stage_3_substage) -> None: regional_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL) national_artifacts = fit_artifacts_for_scope(FitScope.NATIONAL) - regional = _substage("3a_weight_fitting_regional") - national = _substage("3b_weight_fitting_national") + regional = stage_3_substage("3a_weight_fitting_regional") + national = stage_3_substage("3b_weight_fitting_national") regional_nodes = {node["id"]: node for node in regional["extra_nodes"]} national_nodes = {node["id"]: node for node in national["extra_nodes"]} From 8f41e8b6bb2fb169944a7654be11c175f2500bbe Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 16:31:52 +0200 Subject: [PATCH 4/4] Bump policyengine-us dependency --- changelog.d/1045.changed | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog.d/1045.changed b/changelog.d/1045.changed index 2d3014a32..1a5552ef9 100644 --- a/changelog.d/1045.changed +++ b/changelog.d/1045.changed @@ -1 +1,2 @@ Added scoped Stage 3 fitted-weight input and output bundles. +Bumped policyengine-us to 1.701.1.