Skip to content

Commit 905ca62

Browse files
committed
Add Stage 2 package payload reader
1 parent 5f4be4c commit 905ca62

10 files changed

Lines changed: 630 additions & 219 deletions

File tree

changelog.d/1073.changed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add typed Stage 2 calibration package payload reader and writer helpers.

docs/pipeline_map.yaml

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -821,8 +821,12 @@ stages:
821821
- build_matrix
822822
- build_matrix_chunked
823823
- stage2_artifact_specs
824+
- stage2_payload_boundary
825+
- stage2_payload_writer
826+
- stage2_payload_reader
824827
- stage2_calibration_package_writer
825828
- out_pkg
829+
- out_metadata
826830
- stage2_calibration_package_contract_writer
827831
- out_contract
828832
- stage2_calibration_package_contract_validator
@@ -875,6 +879,10 @@ stages:
875879
label: calibration_package.pkl
876880
node_type: artifact
877881
description: X_sparse CSR matrix, targets_df, target_names, initial_weights, metadata, cd_geoid, block_geoid
882+
- id: out_metadata
883+
label: calibration_package_meta.json
884+
node_type: artifact
885+
description: Metadata sidecar generated from the typed package payload and Stage 2 contract
878886
- id: out_contract
879887
label: calibration_package_contract.json
880888
node_type: artifact
@@ -983,21 +991,40 @@ stages:
983991
edge_type: uses_library
984992
label: chunked path
985993
- source: build_matrix
986-
target: stage2_calibration_package_writer
994+
target: stage2_payload_boundary
987995
edge_type: data_flow
988996
- source: build_matrix_chunked
989-
target: stage2_calibration_package_writer
997+
target: stage2_payload_boundary
990998
edge_type: data_flow
999+
- source: stage2_payload_boundary
1000+
target: stage2_payload_writer
1001+
edge_type: data_flow
1002+
label: typed pickle payload
1003+
- source: stage2_payload_writer
1004+
target: stage2_calibration_package_writer
1005+
edge_type: uses_library
1006+
label: pickle write
9911007
- source: stage2_artifact_specs
9921008
target: stage2_calibration_package_writer
9931009
edge_type: uses_utility
9941010
label: package path
9951011
- source: stage2_calibration_package_writer
9961012
target: out_pkg
9971013
edge_type: produces_artifact
1014+
- source: stage2_payload_writer
1015+
target: out_metadata
1016+
edge_type: produces_artifact
1017+
label: sidecar metadata
1018+
- source: out_pkg
1019+
target: stage2_payload_reader
1020+
edge_type: data_flow
9981021
- source: out_pkg
9991022
target: stage2_calibration_package_contract_writer
10001023
edge_type: data_flow
1024+
- source: stage2_payload_reader
1025+
target: stage2_calibration_package_contract_writer
1026+
edge_type: uses_library
1027+
label: summary and checksum
10011028
- source: stage2_artifact_specs
10021029
target: stage2_calibration_package_contract_writer
10031030
edge_type: uses_utility

modal_app/remote_calibration_runner.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,23 +321,30 @@ def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None:
321321

322322

323323
def _write_package_sidecar(pkg_path: str) -> bool:
324-
"""Extract metadata from a pickle package and write a JSON sidecar.
324+
"""Write package metadata from the typed payload and contract sidecar.
325325
326326
Returns:
327327
True if sidecar was written successfully, False otherwise.
328328
"""
329-
import json
330329
import logging
331-
import pickle
332330

333-
sidecar_path = pkg_path.replace(".pkl", "_meta.json")
334331
try:
335-
with open(pkg_path, "rb") as f:
336-
package = pickle.load(f)
337-
meta = package.get("metadata", {})
338-
del package
339-
with open(sidecar_path, "w") as f:
340-
json.dump(meta, f, indent=2)
332+
from policyengine_us_data.calibration_package.payload import (
333+
CalibrationPackageReader,
334+
CalibrationPackageWriter,
335+
)
336+
from policyengine_us_data.calibration_package.specs import (
337+
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
338+
)
339+
from policyengine_us_data.stage_contracts.io import read_contract
340+
341+
package_path = Path(pkg_path)
342+
payload = CalibrationPackageReader(package_path=package_path).read()
343+
contract_path = package_path.with_name(CALIBRATION_PACKAGE_CONTRACT_FILENAME)
344+
contract = read_contract(contract_path) if contract_path.exists() else None
345+
sidecar_path = CalibrationPackageWriter(
346+
package_path=package_path,
347+
).write_metadata_sidecar(payload, contract=contract)
341348
print(
342349
f"Sidecar metadata written to {sidecar_path}",
343350
flush=True,

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@
4444
from policyengine_us_data.calibration.calibration_utils import (
4545
create_target_groups,
4646
)
47+
from policyengine_us_data.calibration_package.payload import (
48+
CalibrationPackagePayload,
49+
CalibrationPackageReader,
50+
CalibrationPackageWriter,
51+
)
4752
from policyengine_us_data.calibration_package.specs import (
4853
DEFAULT_TARGET_CONFIG_PATH as DEFAULT_TARGET_CONFIG_RELATIVE_PATH,
4954
TargetConfigIdentity,
@@ -677,20 +682,16 @@ def save_calibration_package(
677682
cd_geoid: CD GEOID array from geography assignment.
678683
block_geoid: Block GEOID array from geography assignment.
679684
"""
680-
import pickle
681-
682-
package = {
683-
"X_sparse": X_sparse,
684-
"targets_df": targets_df,
685-
"target_names": target_names,
686-
"metadata": metadata,
687-
"initial_weights": initial_weights,
688-
"cd_geoid": cd_geoid,
689-
"block_geoid": block_geoid,
690-
}
691-
Path(path).parent.mkdir(parents=True, exist_ok=True)
692-
with open(path, "wb") as f:
693-
pickle.dump(package, f, protocol=pickle.HIGHEST_PROTOCOL)
685+
payload = CalibrationPackagePayload(
686+
X_sparse=X_sparse,
687+
targets_df=targets_df,
688+
target_names=target_names,
689+
metadata=metadata,
690+
initial_weights=initial_weights,
691+
cd_geoid=cd_geoid,
692+
block_geoid=block_geoid,
693+
)
694+
CalibrationPackageWriter(package_path=Path(path)).write(payload)
694695
logger.info("Calibration package saved to %s", path)
695696

696697

@@ -703,16 +704,14 @@ def load_calibration_package(path: str) -> dict:
703704
Returns:
704705
Dict with X_sparse, targets_df, target_names, metadata.
705706
"""
706-
import pickle
707-
708-
with open(path, "rb") as f:
709-
package = pickle.load(f)
707+
payload = CalibrationPackageReader(package_path=Path(path)).read()
708+
package = payload.to_mapping()
710709
logger.info(
711710
"Loaded package: %d targets, %d records",
712-
package["X_sparse"].shape[0],
713-
package["X_sparse"].shape[1],
711+
payload.X_sparse.shape[0],
712+
payload.X_sparse.shape[1],
714713
)
715-
meta = package.get("metadata", {})
714+
meta = payload.metadata
716715
print_package_provenance(meta)
717716
check_package_staleness(meta)
718717
return package
@@ -1727,15 +1726,15 @@ def run_calibration(
17271726

17281727
initial_weights = compute_initial_weights(X_sparse, targets_df)
17291728
if package_output_path:
1730-
package_payload = {
1731-
"X_sparse": X_sparse,
1732-
"targets_df": targets_df,
1733-
"target_names": target_names,
1734-
"metadata": metadata,
1735-
"initial_weights": initial_weights,
1736-
"cd_geoid": geography.cd_geoid,
1737-
"block_geoid": geography.block_geoid,
1738-
}
1729+
package_payload = CalibrationPackagePayload(
1730+
X_sparse=X_sparse,
1731+
targets_df=targets_df,
1732+
target_names=target_names,
1733+
metadata=metadata,
1734+
initial_weights=initial_weights,
1735+
cd_geoid=geography.cd_geoid,
1736+
block_geoid=geography.block_geoid,
1737+
)
17391738
save_calibration_package(
17401739
package_output_path,
17411740
X_sparse,

policyengine_us_data/calibration_package/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
stage2_input_bundle_from_stage1_contract,
2727
stage2_input_bundle_from_stage1_contract_path,
2828
)
29+
from .payload import (
30+
LEGACY_MISSING_GEOGRAPHY_WARNING,
31+
REQUIRED_PACKAGE_KEYS,
32+
CalibrationPackagePayload,
33+
CalibrationPackageReader,
34+
CalibrationPackageWriter,
35+
)
2936

3037
__all__ = [
3138
"CALIBRATION_PACKAGE_CONTRACT_FILENAME",
@@ -41,10 +48,15 @@
4148
"TARGET_DATABASE_FILENAME",
4249
"CalibrationPackageArtifactPaths",
4350
"CalibrationPackageOutputBundle",
51+
"CalibrationPackagePayload",
52+
"CalibrationPackageReader",
53+
"CalibrationPackageWriter",
54+
"LEGACY_MISSING_GEOGRAPHY_WARNING",
4455
"Stage2BuildContext",
4556
"Stage2InputBundle",
4657
"Stage2InputBundleError",
4758
"Stage2InputSource",
59+
"REQUIRED_PACKAGE_KEYS",
4860
"TargetConfigIdentity",
4961
"calibration_package_artifact_paths",
5062
"resolve_target_config_identity",

0 commit comments

Comments
 (0)