Skip to content

Commit 712354b

Browse files
committed
Add Stage 2 package payload reader
1 parent 4544761 commit 712354b

10 files changed

Lines changed: 629 additions & 221 deletions

File tree

changelog.d/1073.changed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add typed Stage 2 calibration package payload reader and writer helpers.

docs/pipeline_map.yaml

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -821,8 +821,12 @@ stages:
821821
- build_matrix
822822
- build_matrix_chunked
823823
- stage2_artifact_specs
824+
- stage2_payload_boundary
825+
- stage2_payload_writer
826+
- stage2_payload_reader
824827
- stage2_calibration_package_writer
825828
- out_pkg
829+
- out_metadata
826830
- stage2_calibration_package_contract_writer
827831
- out_contract
828832
- stage2_calibration_package_contract_validator
@@ -875,6 +879,10 @@ stages:
875879
label: calibration_package.pkl
876880
node_type: artifact
877881
description: X_sparse CSR matrix, targets_df, initial_weights, metadata
882+
- id: out_metadata
883+
label: calibration_package_meta.json
884+
node_type: artifact
885+
description: Metadata sidecar generated from the typed package payload and Stage 2 contract
878886
- id: out_contract
879887
label: calibration_package_contract.json
880888
node_type: artifact
@@ -983,21 +991,40 @@ stages:
983991
edge_type: uses_library
984992
label: chunked path
985993
- source: build_matrix
986-
target: stage2_calibration_package_writer
994+
target: stage2_payload_boundary
987995
edge_type: data_flow
988996
- source: build_matrix_chunked
989-
target: stage2_calibration_package_writer
997+
target: stage2_payload_boundary
990998
edge_type: data_flow
999+
- source: stage2_payload_boundary
1000+
target: stage2_payload_writer
1001+
edge_type: data_flow
1002+
label: typed pickle payload
1003+
- source: stage2_payload_writer
1004+
target: stage2_calibration_package_writer
1005+
edge_type: uses_library
1006+
label: pickle write
9911007
- source: stage2_artifact_specs
9921008
target: stage2_calibration_package_writer
9931009
edge_type: uses_utility
9941010
label: package path
9951011
- source: stage2_calibration_package_writer
9961012
target: out_pkg
9971013
edge_type: produces_artifact
1014+
- source: stage2_payload_writer
1015+
target: out_metadata
1016+
edge_type: produces_artifact
1017+
label: sidecar metadata
1018+
- source: out_pkg
1019+
target: stage2_payload_reader
1020+
edge_type: data_flow
9981021
- source: out_pkg
9991022
target: stage2_calibration_package_contract_writer
10001023
edge_type: data_flow
1024+
- source: stage2_payload_reader
1025+
target: stage2_calibration_package_contract_writer
1026+
edge_type: uses_library
1027+
label: summary and checksum
10011028
- source: stage2_artifact_specs
10021029
target: stage2_calibration_package_contract_writer
10031030
edge_type: uses_utility

modal_app/remote_calibration_runner.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,23 +363,30 @@ def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None:
363363

364364

365365
def _write_package_sidecar(pkg_path: str) -> bool:
366-
"""Extract metadata from a pickle package and write a JSON sidecar.
366+
"""Write package metadata from the typed payload and contract sidecar.
367367
368368
Returns:
369369
True if sidecar was written successfully, False otherwise.
370370
"""
371-
import json
372371
import logging
373-
import pickle
374372

375-
sidecar_path = pkg_path.replace(".pkl", "_meta.json")
376373
try:
377-
with open(pkg_path, "rb") as f:
378-
package = pickle.load(f)
379-
meta = package.get("metadata", {})
380-
del package
381-
with open(sidecar_path, "w") as f:
382-
json.dump(meta, f, indent=2)
374+
from policyengine_us_data.calibration_package.payload import (
375+
CalibrationPackageReader,
376+
CalibrationPackageWriter,
377+
)
378+
from policyengine_us_data.calibration_package.specs import (
379+
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
380+
)
381+
from policyengine_us_data.stage_contracts.io import read_contract
382+
383+
package_path = Path(pkg_path)
384+
payload = CalibrationPackageReader(package_path=package_path).read()
385+
contract_path = package_path.with_name(CALIBRATION_PACKAGE_CONTRACT_FILENAME)
386+
contract = read_contract(contract_path) if contract_path.exists() else None
387+
sidecar_path = CalibrationPackageWriter(
388+
package_path=package_path,
389+
).write_metadata_sidecar(payload, contract=contract)
383390
print(
384391
f"Sidecar metadata written to {sidecar_path}",
385392
flush=True,

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
build_checkpoint_signature,
4242
checkpoint_signature_mismatches,
4343
)
44-
from policyengine_us_data.calibration.calibration_utils import (
45-
create_target_groups,
44+
from policyengine_us_data.calibration_package.payload import (
45+
CalibrationPackagePayload,
46+
CalibrationPackageReader,
47+
CalibrationPackageWriter,
4648
)
4749
from policyengine_us_data.calibration_package.specs import (
4850
DEFAULT_TARGET_CONFIG_PATH as DEFAULT_TARGET_CONFIG_RELATIVE_PATH,
@@ -680,20 +682,16 @@ def save_calibration_package(
680682
cd_geoid: CD GEOID array from geography assignment.
681683
block_geoid: Block GEOID array from geography assignment.
682684
"""
683-
import pickle
684-
685-
package = {
686-
"X_sparse": X_sparse,
687-
"targets_df": targets_df,
688-
"target_names": target_names,
689-
"metadata": metadata,
690-
"initial_weights": initial_weights,
691-
"cd_geoid": cd_geoid,
692-
"block_geoid": block_geoid,
693-
}
694-
Path(path).parent.mkdir(parents=True, exist_ok=True)
695-
with open(path, "wb") as f:
696-
pickle.dump(package, f, protocol=pickle.HIGHEST_PROTOCOL)
685+
payload = CalibrationPackagePayload(
686+
X_sparse=X_sparse,
687+
targets_df=targets_df,
688+
target_names=target_names,
689+
metadata=metadata,
690+
initial_weights=initial_weights,
691+
cd_geoid=cd_geoid,
692+
block_geoid=block_geoid,
693+
)
694+
CalibrationPackageWriter(package_path=Path(path)).write(payload)
697695
logger.info("Calibration package saved to %s", path)
698696

699697

@@ -706,16 +704,14 @@ def load_calibration_package(path: str) -> dict:
706704
Returns:
707705
Dict with X_sparse, targets_df, target_names, metadata.
708706
"""
709-
import pickle
710-
711-
with open(path, "rb") as f:
712-
package = pickle.load(f)
707+
payload = CalibrationPackageReader(package_path=Path(path)).read()
708+
package = payload.to_mapping()
713709
logger.info(
714710
"Loaded package: %d targets, %d records",
715-
package["X_sparse"].shape[0],
716-
package["X_sparse"].shape[1],
711+
payload.X_sparse.shape[0],
712+
payload.X_sparse.shape[1],
717713
)
718-
meta = package.get("metadata", {})
714+
meta = payload.metadata
719715
print_package_provenance(meta)
720716
check_package_staleness(meta)
721717
return package
@@ -1732,15 +1728,15 @@ def run_calibration(
17321728

17331729
initial_weights = compute_initial_weights(X_sparse, targets_df)
17341730
if package_output_path:
1735-
package_payload = {
1736-
"X_sparse": X_sparse,
1737-
"targets_df": targets_df,
1738-
"target_names": target_names,
1739-
"metadata": metadata,
1740-
"initial_weights": initial_weights,
1741-
"cd_geoid": geography.cd_geoid,
1742-
"block_geoid": geography.block_geoid,
1743-
}
1731+
package_payload = CalibrationPackagePayload(
1732+
X_sparse=X_sparse,
1733+
targets_df=targets_df,
1734+
target_names=target_names,
1735+
metadata=metadata,
1736+
initial_weights=initial_weights,
1737+
cd_geoid=geography.cd_geoid,
1738+
block_geoid=geography.block_geoid,
1739+
)
17441740
save_calibration_package(
17451741
package_output_path,
17461742
X_sparse,

policyengine_us_data/calibration_package/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
stage2_input_bundle_from_stage1_contract,
2727
stage2_input_bundle_from_stage1_contract_path,
2828
)
29+
from .payload import (
30+
LEGACY_MISSING_GEOGRAPHY_WARNING,
31+
REQUIRED_PACKAGE_KEYS,
32+
CalibrationPackagePayload,
33+
CalibrationPackageReader,
34+
CalibrationPackageWriter,
35+
)
2936

3037
__all__ = [
3138
"CALIBRATION_PACKAGE_CONTRACT_FILENAME",
@@ -41,10 +48,15 @@
4148
"TARGET_DATABASE_FILENAME",
4249
"CalibrationPackageArtifactPaths",
4350
"CalibrationPackageOutputBundle",
51+
"CalibrationPackagePayload",
52+
"CalibrationPackageReader",
53+
"CalibrationPackageWriter",
54+
"LEGACY_MISSING_GEOGRAPHY_WARNING",
4455
"Stage2BuildContext",
4556
"Stage2InputBundle",
4657
"Stage2InputBundleError",
4758
"Stage2InputSource",
59+
"REQUIRED_PACKAGE_KEYS",
4860
"TargetConfigIdentity",
4961
"calibration_package_artifact_paths",
5062
"resolve_target_config_identity",

0 commit comments

Comments
 (0)