Skip to content

Commit 903cd12

Browse files
committed
Add Stage 2 package identity specs
1 parent 710ba7a commit 903cd12

16 files changed

Lines changed: 759 additions & 30 deletions

changelog.d/1041.changed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Stage 2 calibration package manifests now track the explicit target config identity and contract artifact path.

docs/pipeline_map.yaml

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,15 +805,25 @@ stages:
805805
label: run_calibration()
806806
description: 'Build phase: resolve targets and constraints, assemble clone values, and package the sparse calibration matrix'
807807
node_ids:
808+
- stage2_target_config_identity
809+
- stage2_target_catalog_load
808810
- target_resolve
811+
- stage2_target_config_apply
809812
- target_uprate
810813
- geo_build
811814
- constraint_resolve
812815
- state_precomp
813816
- clone_assembly
814817
- takeup_rerand
815818
- sparse_build
819+
- build_matrix
820+
- build_matrix_chunked
821+
- stage2_artifact_specs
822+
- stage2_calibration_package_writer
816823
- out_pkg
824+
- stage2_calibration_package_contract_writer
825+
- out_contract
826+
- stage2_calibration_package_contract_validator
817827
extra_nodes:
818828
- id: in_cps_s5
819829
label: source_imputed_stratified_extended_cps.h5
@@ -859,6 +869,10 @@ stages:
859869
label: calibration_package.pkl
860870
node_type: artifact
861871
description: X_sparse CSR matrix, targets_df, initial_weights, metadata
872+
- id: out_contract
873+
label: calibration_package_contract.json
874+
node_type: artifact
875+
description: Stage 2 package handoff contract written next to calibration_package.pkl
862876
- id: util_sql
863877
label: sqlalchemy
864878
node_type: utility
@@ -884,12 +898,25 @@ stages:
884898
edge_type: external_source
885899
label: SQL targets
886900
- source: in_config_s5
887-
target: target_resolve
901+
target: stage2_target_config_identity
902+
edge_type: data_flow
903+
label: config file
904+
- source: stage2_target_config_identity
905+
target: stage2_target_catalog_load
906+
edge_type: data_flow
907+
label: resolved path and checksum
908+
- source: stage2_target_catalog_load
909+
target: stage2_target_config_apply
888910
edge_type: data_flow
889-
label: include list
911+
label: include/exclude rules
890912
- source: target_resolve
913+
target: stage2_target_config_apply
914+
edge_type: data_flow
915+
label: candidate targets
916+
- source: stage2_target_config_apply
891917
target: target_uprate
892918
edge_type: data_flow
919+
label: selected targets
893920
- source: target_uprate
894921
target: geo_build
895922
edge_type: data_flow
@@ -917,8 +944,48 @@ stages:
917944
target: sparse_build
918945
edge_type: data_flow
919946
- source: sparse_build
947+
target: build_matrix
948+
edge_type: uses_library
949+
label: non-chunked path
950+
- source: sparse_build
951+
target: build_matrix_chunked
952+
edge_type: uses_library
953+
label: chunked path
954+
- source: build_matrix
955+
target: stage2_calibration_package_writer
956+
edge_type: data_flow
957+
- source: build_matrix_chunked
958+
target: stage2_calibration_package_writer
959+
edge_type: data_flow
960+
- source: stage2_artifact_specs
961+
target: stage2_calibration_package_writer
962+
edge_type: uses_utility
963+
label: package path
964+
- source: stage2_calibration_package_writer
920965
target: out_pkg
921966
edge_type: produces_artifact
967+
- source: out_pkg
968+
target: stage2_calibration_package_contract_writer
969+
edge_type: data_flow
970+
- source: stage2_artifact_specs
971+
target: stage2_calibration_package_contract_writer
972+
edge_type: uses_utility
973+
label: contract path
974+
- source: stage2_calibration_package_contract_writer
975+
target: out_contract
976+
edge_type: produces_artifact
977+
- source: out_pkg
978+
target: stage2_calibration_package_contract_validator
979+
edge_type: validates
980+
- source: out_contract
981+
target: stage2_calibration_package_contract_validator
982+
edge_type: validates
983+
- source: in_cps_s5
984+
target: stage2_calibration_package_contract_validator
985+
edge_type: validates
986+
- source: in_db_s5
987+
target: stage2_calibration_package_contract_validator
988+
edge_type: validates
922989
- source: util_sql
923990
target: target_resolve
924991
edge_type: uses_utility

modal_app/pipeline.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@
9292
write_run_meta,
9393
)
9494
from policyengine_us_data.utils.run_context import RunContext, resolve_run_id # noqa: E402
95+
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
96+
calibration_package_artifact_paths,
97+
resolve_target_config_identity,
98+
)
9599
from policyengine_us_data.utils.error_redaction import ( # noqa: E402
96100
redacted_bounded_error_text,
97101
redact_error_text,
@@ -162,18 +166,25 @@ def _calibration_package_parameters(
162166
workers: int,
163167
n_clones: int,
164168
target_config: str | None,
169+
all_active_targets: bool = False,
165170
skip_county: bool,
166171
chunked_matrix: bool,
167172
chunk_size: int,
168173
parallel_matrix: bool,
169174
num_matrix_workers: int,
170175
) -> dict:
171176
"""Return manifest parameters that affect package construction."""
177+
target_config_identity = resolve_target_config_identity(
178+
target_config,
179+
all_active_targets=all_active_targets,
180+
)
172181
effective_parallel = bool(chunked_matrix and parallel_matrix)
173182
params = {
174183
"workers": workers if not chunked_matrix else None,
175184
"n_clones": n_clones,
176-
"target_config": target_config,
185+
"target_config": target_config_identity.path,
186+
"target_config_sha256": target_config_identity.sha256,
187+
"target_config_mode": target_config_identity.mode,
177188
"skip_county": skip_county,
178189
"chunked_matrix": bool(chunked_matrix),
179190
"chunk_size": chunk_size if chunked_matrix else None,
@@ -547,6 +558,7 @@ def verify_runtime_seams() -> dict:
547558
"modal_app/step_manifests/errors.py",
548559
"modal_app/step_manifests/status.py",
549560
"modal_app/fixtures/h5_cases.py",
561+
"policyengine_us_data/calibration_package/specs.py",
550562
"tests/integration/test_fixture_50hh.h5",
551563
"policyengine_us_data/calibration/target_config.yaml",
552564
"policyengine_us_data/calibration/target_config_full.yaml",
@@ -1238,6 +1250,7 @@ def run_pipeline(
12381250
"database": _artifacts_dir(run_id) / "policy_data.db",
12391251
}
12401252
)
1253+
package_artifacts = calibration_package_artifact_paths(_artifacts_dir(run_id))
12411254
package_parameters = _calibration_package_parameters(
12421255
workers=num_workers,
12431256
n_clones=n_clones,
@@ -1302,8 +1315,7 @@ def run_pipeline(
13021315
completed_package_manifest = _complete_step_manifest(
13031316
active_step_manifest,
13041317
outputs=collect_artifacts(
1305-
[_artifacts_dir(run_id) / "calibration_package.pkl"],
1306-
missing_ok=True,
1318+
package_artifacts.manifest_outputs,
13071319
),
13081320
vol=pipeline_volume,
13091321
)

modal_app/remote_calibration_runner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
sys.path.insert(0, _p)
1313

1414
from modal_app.images import gpu_image as image # noqa: E402
15+
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
16+
calibration_package_artifact_paths,
17+
)
1518

1619
app = modal.App(
1720
os.environ.get("US_DATA_FIT_WEIGHTS_APP_NAME") or "policyengine-us-data-fit-weights"
@@ -379,7 +382,8 @@ def _build_package_impl(
379382
f"Missing {label} on pipeline volume: {p}. Run data_build first."
380383
)
381384

382-
pkg_path = f"{artifacts}/calibration_package.pkl"
385+
package_artifacts = calibration_package_artifact_paths(artifacts)
386+
pkg_path = str(package_artifacts.package)
383387
cmd = [
384388
*_python_cmd("-m", "policyengine_us_data.calibration.unified_calibration"),
385389
"--device",
@@ -404,7 +408,7 @@ def _build_package_impl(
404408
if chunked_matrix:
405409
cmd.extend(["--chunked-matrix", "--chunk-size", str(chunk_size)])
406410
if parallel_matrix:
407-
chunk_dir = f"{artifacts}/matrix_build"
411+
chunk_dir = str(package_artifacts.matrix_build_dir)
408412
cmd.extend(
409413
[
410414
"--parallel",
@@ -439,14 +443,12 @@ def _build_package_impl(
439443
raise RuntimeError(f"Package build failed with code {build_rc}")
440444

441445
from policyengine_us_data.stage_contracts.calibration_package import (
442-
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
443446
validate_persisted_calibration_package_contract,
444447
)
445448

446-
contract_path = f"{artifacts}/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
447449
validate_persisted_calibration_package_contract(
448-
package_path=Path(pkg_path),
449-
contract_path=Path(contract_path),
450+
package_path=package_artifacts.package,
451+
contract_path=package_artifacts.contract,
450452
dataset_path=Path(dataset_path),
451453
db_path=Path(db_path),
452454
)
@@ -525,8 +527,9 @@ def check_volume_package(artifacts_dir: str = "") -> dict:
525527
import json
526528

527529
base = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts"
528-
pkg_path = f"{base}/calibration_package.pkl"
529-
sidecar_path = f"{base}/calibration_package_meta.json"
530+
package_artifacts = calibration_package_artifact_paths(base)
531+
pkg_path = str(package_artifacts.package)
532+
sidecar_path = str(package_artifacts.metadata)
530533
if not os.path.exists(pkg_path):
531534
return {"exists": False}
532535

0 commit comments

Comments
 (0)