Skip to content

Commit 5f4be4c

Browse files
committed
Add Stage 2 input artifact bundles
1 parent 903cd12 commit 5f4be4c

9 files changed

Lines changed: 546 additions & 26 deletions

File tree

changelog.d/1065.changed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Stage 2 calibration package construction now resolves its inputs and outputs through run-scoped artifact bundles.

docs/pipeline_map.yaml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,8 @@ stages:
805805
label: run_calibration()
806806
description: 'Build phase: resolve targets and constraints, assemble clone values, and package the sparse calibration matrix'
807807
node_ids:
808+
- stage2_input_bundle
809+
- stage2_build_context
808810
- stage2_target_config_identity
809811
- stage2_target_catalog_load
810812
- target_resolve
@@ -825,6 +827,10 @@ stages:
825827
- out_contract
826828
- stage2_calibration_package_contract_validator
827829
extra_nodes:
830+
- id: in_stage1_contract_s2
831+
label: dataset_build_output.json
832+
node_type: artifact
833+
description: Stage 1 handoff contract preferred for Stage 2 input resolution
828834
- id: in_cps_s5
829835
label: source_imputed_stratified_extended_cps.h5
830836
node_type: artifact
@@ -890,9 +896,34 @@ stages:
890896
node_type: utility
891897
description: CSR/COO matrix construction
892898
edges:
899+
- source: in_stage1_contract_s2
900+
target: stage2_input_bundle
901+
edge_type: data_flow
902+
label: preferred input contract
893903
- source: in_cps_s5
904+
target: stage2_input_bundle
905+
edge_type: data_flow
906+
label: compatibility fallback
907+
- source: in_db_s5
908+
target: stage2_input_bundle
909+
edge_type: external_source
910+
label: compatibility fallback
911+
- source: stage2_input_bundle
912+
target: stage2_build_context
913+
edge_type: data_flow
914+
label: validated inputs
915+
- source: stage2_artifact_specs
916+
target: stage2_build_context
917+
edge_type: uses_utility
918+
label: output bundle paths
919+
- source: stage2_build_context
894920
target: target_resolve
895921
edge_type: data_flow
922+
label: dataset and database paths
923+
- source: stage2_build_context
924+
target: stage2_calibration_package_writer
925+
edge_type: uses_utility
926+
label: package output bundle
896927
- source: in_db_s5
897928
target: target_resolve
898929
edge_type: external_source

modal_app/pipeline.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@
9393
)
9494
from policyengine_us_data.utils.run_context import RunContext, resolve_run_id # noqa: E402
9595
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
96-
calibration_package_artifact_paths,
96+
Stage2InputBundleError,
9797
resolve_target_config_identity,
98+
stage2_build_context_for_run,
9899
)
99100
from policyengine_us_data.utils.error_redaction import ( # noqa: E402
100101
redacted_bounded_error_text,
@@ -1243,14 +1244,13 @@ def run_pipeline(
12431244
print(f" Completed in {completed_build_manifest.duration_s}s")
12441245

12451246
# ── Step 2: Build calibration package ──
1247+
package_context = stage2_build_context_for_run(PIPELINE_MOUNT, run_id)
1248+
package_input_validation = package_context.input_bundle.validation_report()
12461249
package_inputs = _artifact_identities(
1247-
{
1248-
"dataset": _artifacts_dir(run_id)
1249-
/ "source_imputed_stratified_extended_cps.h5",
1250-
"database": _artifacts_dir(run_id) / "policy_data.db",
1251-
}
1250+
package_context.input_bundle.manifest_inputs
12521251
)
1253-
package_artifacts = calibration_package_artifact_paths(_artifacts_dir(run_id))
1252+
package_inputs["input_validation"] = package_input_validation.to_dict()
1253+
package_artifacts = package_context.output_bundle
12541254
package_parameters = _calibration_package_parameters(
12551255
workers=num_workers,
12561256
n_clones=n_clones,
@@ -1261,6 +1261,18 @@ def run_pipeline(
12611261
parallel_matrix=parallel_matrix,
12621262
num_matrix_workers=num_matrix_workers,
12631263
)
1264+
if package_input_validation.status != "pass":
1265+
active_step_manifest = _start_step_manifest(
1266+
meta,
1267+
BUILD_CALIBRATION_PACKAGE,
1268+
parameters=package_parameters,
1269+
input_identities=package_inputs,
1270+
vol=pipeline_volume,
1271+
)
1272+
raise Stage2InputBundleError(
1273+
package_context.input_bundle,
1274+
package_input_validation,
1275+
)
12641276
package_reuse = _step_reusable(
12651277
meta,
12661278
BUILD_CALIBRATION_PACKAGE,

modal_app/remote_calibration_runner.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from modal_app.images import gpu_image as image # noqa: E402
1515
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
1616
calibration_package_artifact_paths,
17+
stage2_build_context_for_run,
1718
)
1819

1920
app = modal.App(
@@ -371,18 +372,13 @@ def _build_package_impl(
371372
_ensure_geography_prerequisites()
372373

373374
pipeline_vol.reload()
374-
artifacts = f"{PIPELINE_MOUNT}/artifacts"
375-
if run_id:
376-
artifacts = f"{artifacts}/{run_id}"
377-
db_path = f"{artifacts}/policy_data.db"
378-
dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5"
379-
for label, p in [("database", db_path), ("dataset", dataset_path)]:
380-
if not os.path.exists(p):
381-
raise RuntimeError(
382-
f"Missing {label} on pipeline volume: {p}. Run data_build first."
383-
)
384-
385-
package_artifacts = calibration_package_artifact_paths(artifacts)
375+
build_context = stage2_build_context_for_run(
376+
PIPELINE_MOUNT, run_id
377+
).require_inputs()
378+
input_bundle = build_context.input_bundle
379+
package_artifacts = build_context.output_bundle
380+
db_path = str(input_bundle.target_database)
381+
dataset_path = str(input_bundle.source_dataset)
386382
pkg_path = str(package_artifacts.package)
387383
cmd = [
388384
*_python_cmd("-m", "policyengine_us_data.calibration.unified_calibration"),

policyengine_us_data/calibration_package/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,51 @@
55
CALIBRATION_PACKAGE_FILENAME,
66
CALIBRATION_PACKAGE_METADATA_FILENAME,
77
CALIBRATION_PACKAGE_SUBSTAGE_ID,
8+
CALIBRATION_REPORTS_DIRNAME,
9+
DATASET_BUILD_OUTPUT_CONTRACT_FILENAME,
810
DEFAULT_TARGET_CONFIG_PATH,
911
MATRIX_BUILD_DIRNAME,
12+
SOURCE_DATASET_FILENAME,
1013
TARGET_CONFIG_IDENTITY_MODES,
14+
TARGET_DATABASE_FILENAME,
1115
CalibrationPackageArtifactPaths,
16+
CalibrationPackageOutputBundle,
17+
Stage2BuildContext,
18+
Stage2InputBundle,
19+
Stage2InputBundleError,
20+
Stage2InputSource,
1221
TargetConfigIdentity,
1322
calibration_package_artifact_paths,
1423
resolve_target_config_identity,
24+
stage2_build_context_for_run,
25+
stage2_input_bundle_from_artifacts_dir,
26+
stage2_input_bundle_from_stage1_contract,
27+
stage2_input_bundle_from_stage1_contract_path,
1528
)
1629

1730
__all__ = [
1831
"CALIBRATION_PACKAGE_CONTRACT_FILENAME",
1932
"CALIBRATION_PACKAGE_FILENAME",
2033
"CALIBRATION_PACKAGE_METADATA_FILENAME",
2134
"CALIBRATION_PACKAGE_SUBSTAGE_ID",
35+
"CALIBRATION_REPORTS_DIRNAME",
36+
"DATASET_BUILD_OUTPUT_CONTRACT_FILENAME",
2237
"DEFAULT_TARGET_CONFIG_PATH",
2338
"MATRIX_BUILD_DIRNAME",
39+
"SOURCE_DATASET_FILENAME",
2440
"TARGET_CONFIG_IDENTITY_MODES",
41+
"TARGET_DATABASE_FILENAME",
2542
"CalibrationPackageArtifactPaths",
43+
"CalibrationPackageOutputBundle",
44+
"Stage2BuildContext",
45+
"Stage2InputBundle",
46+
"Stage2InputBundleError",
47+
"Stage2InputSource",
2648
"TargetConfigIdentity",
2749
"calibration_package_artifact_paths",
2850
"resolve_target_config_identity",
51+
"stage2_build_context_for_run",
52+
"stage2_input_bundle_from_artifacts_dir",
53+
"stage2_input_bundle_from_stage1_contract",
54+
"stage2_input_bundle_from_stage1_contract_path",
2955
]

0 commit comments

Comments
 (0)