Skip to content

Commit cc0771d

Browse files
committed
Add typed Stage 2 contract payload schemas
1 parent 13e2559 commit cc0771d

7 files changed

Lines changed: 577 additions & 55 deletions

File tree

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,9 @@
4545
create_target_groups,
4646
)
4747
from policyengine_us_data.pipeline_metadata import pipeline_node
48+
from policyengine_us_data.stage_contracts.calibration_package import (
49+
CalibrationPackageParameters,
50+
)
4851
from policyengine_us_data.pipeline_schema import PipelineNode
4952

5053
logging.basicConfig(
@@ -90,22 +93,21 @@ def _calibration_package_contract_parameters(
9093
chunk_size: int,
9194
parallel: bool,
9295
num_matrix_workers: int,
93-
) -> dict:
96+
) -> CalibrationPackageParameters:
9497
"""Return Stage 2 parameters that affect package construction."""
9598

96-
parallel_matrix = bool(chunked_matrix and parallel)
97-
return {
98-
"workers": workers if not chunked_matrix else None,
99-
"n_clones": n_clones,
100-
"target_config": target_config_path,
101-
"skip_county": skip_county,
102-
"skip_source_impute": skip_source_impute,
103-
"skip_takeup_rerandomize": skip_takeup_rerandomize,
104-
"chunked_matrix": chunked_matrix,
105-
"chunk_size": chunk_size if chunked_matrix else None,
106-
"parallel_matrix": parallel_matrix,
107-
"num_matrix_workers": num_matrix_workers if parallel_matrix else None,
108-
}
99+
return CalibrationPackageParameters.from_runtime_args(
100+
workers=workers,
101+
n_clones=n_clones,
102+
target_config_path=target_config_path,
103+
skip_county=skip_county,
104+
skip_source_impute=skip_source_impute,
105+
skip_takeup_rerandomize=skip_takeup_rerandomize,
106+
chunked_matrix=chunked_matrix,
107+
chunk_size=chunk_size,
108+
parallel=parallel,
109+
num_matrix_workers=num_matrix_workers,
110+
)
109111

110112

111113
def get_git_provenance() -> dict:

policyengine_us_data/stage_contracts/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
from .calibration_package import (
1616
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
1717
CALIBRATION_PACKAGE_CONTRACT_TYPE,
18+
CalibrationPackageParameters,
19+
CalibrationPackageSummary,
1820
build_calibration_package_contract,
1921
load_calibration_package_payload,
2022
summarize_calibration_package,
@@ -85,6 +87,8 @@
8587
"CANONICAL_STAGE_IDS",
8688
"CALIBRATION_PACKAGE_CONTRACT_FILENAME",
8789
"CALIBRATION_PACKAGE_CONTRACT_TYPE",
90+
"CalibrationPackageParameters",
91+
"CalibrationPackageSummary",
8892
"CONTRACT_TYPE_BY_STAGE_ID",
8993
"DATASET_BUILD_OUTPUT_CONTRACT_FILENAME",
9094
"DATASET_BUILD_OUTPUT_CONTRACT_TYPE",

policyengine_us_data/stage_contracts/calibration_package.py

Lines changed: 51 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,10 @@
1010
from policyengine_us_data.utils.step_manifest import sha256_file
1111

1212
from .artifacts import ArtifactRef
13+
from .calibration_package_schema import (
14+
CalibrationPackageParameters,
15+
CalibrationPackageSummary,
16+
)
1317
from .contracts import StageContract
1418
from .execution import ExecutionRecord, ReuseSummary
1519
from .fingerprints import canonicalize_for_fingerprint, fingerprint_material
@@ -24,7 +28,9 @@
2428
CALIBRATION_PACKAGE_SUBSTAGE_ID = "2a_matrix_build_calibration_target_construction"
2529

2630

27-
def summarize_calibration_package(package: Mapping[str, Any]) -> Mapping[str, Any]:
31+
def summarize_calibration_package(
32+
package: Mapping[str, Any],
33+
) -> CalibrationPackageSummary:
2834
"""Return a contract-safe summary of a calibration package pickle payload."""
2935

3036
matrix = _required_package_value(package, "X_sparse")
@@ -44,37 +50,36 @@ def summarize_calibration_package(package: Mapping[str, Any]) -> Mapping[str, An
4450
nnz = int(matrix.nnz)
4551
density = nnz / (n_targets * n_columns) if n_targets * n_columns else 0.0
4652

47-
summary: dict[str, Any] = {
48-
"matrix_shape": (n_targets, n_columns),
49-
"matrix_nnz": nnz,
50-
"matrix_density": float(density),
51-
"n_targets": int(len(targets_df)),
52-
"n_columns": n_columns,
53-
"target_name_count": int(len(target_names)),
54-
"dataset_sha256": _optional_metadata_string(metadata, "dataset_sha256"),
55-
"db_sha256": _optional_metadata_string(metadata, "db_sha256"),
56-
"target_config_path": _optional_metadata_string(
53+
return CalibrationPackageSummary(
54+
matrix_shape=(n_targets, n_columns),
55+
matrix_nnz=nnz,
56+
matrix_density=float(density),
57+
n_targets=int(len(targets_df)),
58+
n_columns=n_columns,
59+
target_name_count=int(len(target_names)),
60+
dataset_sha256=_optional_metadata_string(metadata, "dataset_sha256"),
61+
db_sha256=_optional_metadata_string(metadata, "db_sha256"),
62+
target_config_path=_optional_metadata_string(
5763
metadata,
5864
"target_config_path",
5965
),
60-
"target_config_sha256": _optional_metadata_string(
66+
target_config_sha256=_optional_metadata_string(
6167
metadata,
6268
"target_config_sha256",
6369
),
64-
"n_clones": _optional_metadata_int(metadata, "n_clones"),
65-
"seed": _optional_metadata_int(metadata, "seed"),
66-
"base_n_records": _optional_metadata_int(metadata, "base_n_records"),
67-
"package_scope": _optional_metadata_string(metadata, "package_scope"),
68-
"matrix_builder": _optional_metadata_string(metadata, "matrix_builder"),
69-
"chunk_size": _optional_metadata_int(metadata, "chunk_size"),
70-
"chunk_dir": _optional_metadata_string(metadata, "chunk_dir"),
71-
"has_initial_weights": package.get("initial_weights") is not None,
72-
"has_cd_geoid": package.get("cd_geoid") is not None,
73-
"has_block_geoid": package.get("block_geoid") is not None,
74-
"cd_geoid_length": _optional_len(package.get("cd_geoid")),
75-
"block_geoid_length": _optional_len(package.get("block_geoid")),
76-
}
77-
return summary
70+
n_clones=_optional_metadata_int(metadata, "n_clones"),
71+
seed=_optional_metadata_int(metadata, "seed"),
72+
base_n_records=_optional_metadata_int(metadata, "base_n_records"),
73+
package_scope=_optional_metadata_string(metadata, "package_scope"),
74+
matrix_builder=_optional_metadata_string(metadata, "matrix_builder"),
75+
chunk_size=_optional_metadata_int(metadata, "chunk_size"),
76+
chunk_dir=_optional_metadata_string(metadata, "chunk_dir"),
77+
has_initial_weights=package.get("initial_weights") is not None,
78+
has_cd_geoid=package.get("cd_geoid") is not None,
79+
has_block_geoid=package.get("block_geoid") is not None,
80+
cd_geoid_length=_optional_len(package.get("cd_geoid")),
81+
block_geoid_length=_optional_len(package.get("block_geoid")),
82+
)
7883

7984

8085
def build_calibration_package_contract(
@@ -83,7 +88,7 @@ def build_calibration_package_contract(
8388
dataset_path: Path,
8489
db_path: Path,
8590
package: Mapping[str, Any],
86-
parameters: Mapping[str, Any],
91+
parameters: CalibrationPackageParameters | Mapping[str, Any],
8792
run_id: str | None,
8893
completed_at: str,
8994
started_at: str | None = None,
@@ -100,8 +105,10 @@ def build_calibration_package_contract(
100105
_require_existing_file(dataset_path, "source dataset")
101106
_require_existing_file(db_path, "target database")
102107

108+
parameter_schema = _calibration_package_parameters(parameters)
109+
parameter_payload = parameter_schema.to_dict()
103110
metadata = _package_metadata(package)
104-
package_summary = summarize_calibration_package(package)
111+
package_summary = summarize_calibration_package(package).to_dict()
105112
inputs = (
106113
_artifact_ref_from_path(
107114
logical_name="source_imputed_stratified_extended_cps",
@@ -156,7 +163,7 @@ def build_calibration_package_contract(
156163
"contract_type": CALIBRATION_PACKAGE_CONTRACT_TYPE,
157164
"inputs": inputs,
158165
"outputs": outputs,
159-
"parameters": parameters,
166+
"parameters": parameter_payload,
160167
"package_summary": package_summary,
161168
}
162169
)
@@ -169,15 +176,15 @@ def build_calibration_package_contract(
169176
package_version=package_version,
170177
inputs=inputs,
171178
outputs=outputs,
172-
parameters=parameters,
179+
parameters=parameter_payload,
173180
fingerprint=fingerprint,
174181
substages=(
175182
SubstageRecord(
176183
substage_id=CALIBRATION_PACKAGE_SUBSTAGE_ID,
177184
status="completed",
178185
inputs=inputs,
179186
outputs=outputs,
180-
parameters=parameters,
187+
parameters=parameter_payload,
181188
fingerprint=fingerprint,
182189
reuse_mode="handoff",
183190
),
@@ -197,7 +204,7 @@ def write_calibration_package_contract(
197204
dataset_path: Path,
198205
db_path: Path,
199206
package: Mapping[str, Any],
200-
parameters: Mapping[str, Any],
207+
parameters: CalibrationPackageParameters | Mapping[str, Any],
201208
run_id: str | None,
202209
completed_at: str,
203210
started_at: str | None = None,
@@ -272,10 +279,12 @@ def validate_calibration_package_contract(
272279
raise ValueError("package is required to validate calibration package summary")
273280

274281
expected_summary = canonicalize_for_fingerprint(
275-
summarize_calibration_package(package)
282+
summarize_calibration_package(package).to_dict()
276283
)
277284
actual_summary = canonicalize_for_fingerprint(
278-
contract.metadata.get("package_summary", {})
285+
CalibrationPackageSummary.from_dict(
286+
contract.metadata.get("package_summary", {})
287+
).to_dict()
279288
)
280289
if actual_summary != expected_summary:
281290
raise ValueError("Calibration package contract summary does not match pickle")
@@ -351,6 +360,14 @@ def _optional_len(value: Any) -> int | None:
351360
return int(len(value))
352361

353362

363+
def _calibration_package_parameters(
364+
parameters: CalibrationPackageParameters | Mapping[str, Any],
365+
) -> CalibrationPackageParameters:
366+
if isinstance(parameters, CalibrationPackageParameters):
367+
return parameters
368+
return CalibrationPackageParameters.from_dict(parameters)
369+
370+
354371
def _require_existing_file(path: Path, label: str) -> None:
355372
if not path.exists():
356373
raise FileNotFoundError(f"Missing {label}: {path}")

0 commit comments

Comments
 (0)