1010from policyengine_us_data .utils .step_manifest import sha256_file
1111
1212from .artifacts import ArtifactRef
13+ from .calibration_package_schema import (
14+ CalibrationPackageParameters ,
15+ CalibrationPackageSummary ,
16+ )
1317from .contracts import StageContract
1418from .execution import ExecutionRecord , ReuseSummary
1519from .fingerprints import canonicalize_for_fingerprint , fingerprint_material
2428CALIBRATION_PACKAGE_SUBSTAGE_ID = "2a_matrix_build_calibration_target_construction"
2529
2630
27- def summarize_calibration_package (package : Mapping [str , Any ]) -> Mapping [str , Any ]:
31+ def summarize_calibration_package (
32+ package : Mapping [str , Any ],
33+ ) -> CalibrationPackageSummary :
2834 """Return a contract-safe summary of a calibration package pickle payload."""
2935
3036 matrix = _required_package_value (package , "X_sparse" )
@@ -44,37 +50,36 @@ def summarize_calibration_package(package: Mapping[str, Any]) -> Mapping[str, An
4450 nnz = int (matrix .nnz )
4551 density = nnz / (n_targets * n_columns ) if n_targets * n_columns else 0.0
4652
47- summary : dict [ str , Any ] = {
48- " matrix_shape" : (n_targets , n_columns ),
49- " matrix_nnz" : nnz ,
50- " matrix_density" : float (density ),
51- " n_targets" : int (len (targets_df )),
52- " n_columns" : n_columns ,
53- " target_name_count" : int (len (target_names )),
54- " dataset_sha256" : _optional_metadata_string (metadata , "dataset_sha256" ),
55- " db_sha256" : _optional_metadata_string (metadata , "db_sha256" ),
56- " target_config_path" : _optional_metadata_string (
53+ return CalibrationPackageSummary (
54+ matrix_shape = (n_targets , n_columns ),
55+ matrix_nnz = nnz ,
56+ matrix_density = float (density ),
57+ n_targets = int (len (targets_df )),
58+ n_columns = n_columns ,
59+ target_name_count = int (len (target_names )),
60+ dataset_sha256 = _optional_metadata_string (metadata , "dataset_sha256" ),
61+ db_sha256 = _optional_metadata_string (metadata , "db_sha256" ),
62+ target_config_path = _optional_metadata_string (
5763 metadata ,
5864 "target_config_path" ,
5965 ),
60- " target_config_sha256" : _optional_metadata_string (
66+ target_config_sha256 = _optional_metadata_string (
6167 metadata ,
6268 "target_config_sha256" ,
6369 ),
64- "n_clones" : _optional_metadata_int (metadata , "n_clones" ),
65- "seed" : _optional_metadata_int (metadata , "seed" ),
66- "base_n_records" : _optional_metadata_int (metadata , "base_n_records" ),
67- "package_scope" : _optional_metadata_string (metadata , "package_scope" ),
68- "matrix_builder" : _optional_metadata_string (metadata , "matrix_builder" ),
69- "chunk_size" : _optional_metadata_int (metadata , "chunk_size" ),
70- "chunk_dir" : _optional_metadata_string (metadata , "chunk_dir" ),
71- "has_initial_weights" : package .get ("initial_weights" ) is not None ,
72- "has_cd_geoid" : package .get ("cd_geoid" ) is not None ,
73- "has_block_geoid" : package .get ("block_geoid" ) is not None ,
74- "cd_geoid_length" : _optional_len (package .get ("cd_geoid" )),
75- "block_geoid_length" : _optional_len (package .get ("block_geoid" )),
76- }
77- return summary
70+ n_clones = _optional_metadata_int (metadata , "n_clones" ),
71+ seed = _optional_metadata_int (metadata , "seed" ),
72+ base_n_records = _optional_metadata_int (metadata , "base_n_records" ),
73+ package_scope = _optional_metadata_string (metadata , "package_scope" ),
74+ matrix_builder = _optional_metadata_string (metadata , "matrix_builder" ),
75+ chunk_size = _optional_metadata_int (metadata , "chunk_size" ),
76+ chunk_dir = _optional_metadata_string (metadata , "chunk_dir" ),
77+ has_initial_weights = package .get ("initial_weights" ) is not None ,
78+ has_cd_geoid = package .get ("cd_geoid" ) is not None ,
79+ has_block_geoid = package .get ("block_geoid" ) is not None ,
80+ cd_geoid_length = _optional_len (package .get ("cd_geoid" )),
81+ block_geoid_length = _optional_len (package .get ("block_geoid" )),
82+ )
7883
7984
8085def build_calibration_package_contract (
@@ -83,7 +88,7 @@ def build_calibration_package_contract(
8388 dataset_path : Path ,
8489 db_path : Path ,
8590 package : Mapping [str , Any ],
86- parameters : Mapping [str , Any ],
91+ parameters : CalibrationPackageParameters | Mapping [str , Any ],
8792 run_id : str | None ,
8893 completed_at : str ,
8994 started_at : str | None = None ,
@@ -100,8 +105,10 @@ def build_calibration_package_contract(
100105 _require_existing_file (dataset_path , "source dataset" )
101106 _require_existing_file (db_path , "target database" )
102107
108+ parameter_schema = _calibration_package_parameters (parameters )
109+ parameter_payload = parameter_schema .to_dict ()
103110 metadata = _package_metadata (package )
104- package_summary = summarize_calibration_package (package )
111+ package_summary = summarize_calibration_package (package ). to_dict ()
105112 inputs = (
106113 _artifact_ref_from_path (
107114 logical_name = "source_imputed_stratified_extended_cps" ,
@@ -156,7 +163,7 @@ def build_calibration_package_contract(
156163 "contract_type" : CALIBRATION_PACKAGE_CONTRACT_TYPE ,
157164 "inputs" : inputs ,
158165 "outputs" : outputs ,
159- "parameters" : parameters ,
166+ "parameters" : parameter_payload ,
160167 "package_summary" : package_summary ,
161168 }
162169 )
@@ -169,15 +176,15 @@ def build_calibration_package_contract(
169176 package_version = package_version ,
170177 inputs = inputs ,
171178 outputs = outputs ,
172- parameters = parameters ,
179+ parameters = parameter_payload ,
173180 fingerprint = fingerprint ,
174181 substages = (
175182 SubstageRecord (
176183 substage_id = CALIBRATION_PACKAGE_SUBSTAGE_ID ,
177184 status = "completed" ,
178185 inputs = inputs ,
179186 outputs = outputs ,
180- parameters = parameters ,
187+ parameters = parameter_payload ,
181188 fingerprint = fingerprint ,
182189 reuse_mode = "handoff" ,
183190 ),
@@ -197,7 +204,7 @@ def write_calibration_package_contract(
197204 dataset_path : Path ,
198205 db_path : Path ,
199206 package : Mapping [str , Any ],
200- parameters : Mapping [str , Any ],
207+ parameters : CalibrationPackageParameters | Mapping [str , Any ],
201208 run_id : str | None ,
202209 completed_at : str ,
203210 started_at : str | None = None ,
@@ -272,10 +279,12 @@ def validate_calibration_package_contract(
272279 raise ValueError ("package is required to validate calibration package summary" )
273280
274281 expected_summary = canonicalize_for_fingerprint (
275- summarize_calibration_package (package )
282+ summarize_calibration_package (package ). to_dict ()
276283 )
277284 actual_summary = canonicalize_for_fingerprint (
278- contract .metadata .get ("package_summary" , {})
285+ CalibrationPackageSummary .from_dict (
286+ contract .metadata .get ("package_summary" , {})
287+ ).to_dict ()
279288 )
280289 if actual_summary != expected_summary :
281290 raise ValueError ("Calibration package contract summary does not match pickle" )
@@ -351,6 +360,14 @@ def _optional_len(value: Any) -> int | None:
351360 return int (len (value ))
352361
353362
def _calibration_package_parameters(
    parameters: CalibrationPackageParameters | Mapping[str, Any],
) -> CalibrationPackageParameters:
    """Coerce ``parameters`` to a ``CalibrationPackageParameters`` instance.

    A value that already is a ``CalibrationPackageParameters`` is returned
    unchanged (the very same object); any other value is handed to
    ``CalibrationPackageParameters.from_dict`` and that result is returned.
    """
    already_typed = isinstance(parameters, CalibrationPackageParameters)
    return (
        parameters
        if already_typed
        else CalibrationPackageParameters.from_dict(parameters)
    )
369+
370+
354371def _require_existing_file (path : Path , label : str ) -> None :
355372 if not path .exists ():
356373 raise FileNotFoundError (f"Missing { label } : { path } " )
0 commit comments