Commit cd92e21

Merge pull request #1043 from PolicyEngine/agent/stage-3/pr-3a-specs-artifact-identity
Centralize Stage 3 fit specs and artifacts
2 parents a9a1e2a + a52a0e1 commit cd92e21

15 files changed

Lines changed: 896 additions & 80 deletions

changelog.d/1040.changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Centralized Stage 3 fitted-weight specs and artifact filenames.
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Stage 3: Fit Weights
+
+Stage 3 produces scoped fitted-weight artifacts for regional and national H5
+builds. The public identity boundary lives in `policyengine_us_data.fit_weights`:
+
+- `FitScope` names the durable regional and national scopes.
+- `FittedWeightsSpec` defines the scoped optimization parameters recorded in
+  step manifests for reuse decisions.
+- `ScopedFitArtifacts` defines the artifact filenames written by the Modal fit
+  step and consumed by downstream H5 builders.
+
+The current artifact names remain behavior-compatible:
+
+- regional: `calibration_weights.npy`, `geography_assignment.npz`,
+  `unified_run_config.json`, `unified_diagnostics.csv`, and
+  `calibration_log.csv`;
+- national: `national_calibration_weights.npy`,
+  `national_geography_assignment.npz`, `national_unified_run_config.json`,
+  `national_unified_diagnostics.csv`, and `national_calibration_log.csv`.
+
+When changing Stage 3 fitting parameters, artifact names, or scope behavior,
+update the central specs first and then adapt Modal callers to consume those
+specs. Do not add parallel filename constants in orchestration code.
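The centralization described in the Stage 3 doc can be sketched as follows. This is a minimal illustration, not the code in `policyengine_us_data.fit_weights`: `FitScope`, `ScopedFitArtifacts`, `fit_artifacts_for_scope`, and the `.weights`, `.geography`, and `.run_config` attributes appear in this commit, while the enum values, the `FitArtifact` wrapper, and the `diagnostics` and `calibration_log` attribute names are assumptions made for the example.

```python
from dataclasses import dataclass
from enum import Enum


class FitScope(Enum):
    # Member names follow the diff; the string values are assumed.
    REGIONAL = "regional"
    NATIONAL = "national"


@dataclass(frozen=True)
class FitArtifact:
    # Hypothetical wrapper: the diff only shows that artifacts expose `.filename`.
    filename: str


@dataclass(frozen=True)
class ScopedFitArtifacts:
    weights: FitArtifact
    geography: FitArtifact
    run_config: FitArtifact
    diagnostics: FitArtifact  # attribute name assumed
    calibration_log: FitArtifact  # attribute name assumed


# Filename prefix per scope, matching the names listed in the Stage 3 doc.
_PREFIX = {FitScope.REGIONAL: "", FitScope.NATIONAL: "national_"}


def fit_artifacts_for_scope(scope: FitScope) -> ScopedFitArtifacts:
    """Return the artifact filenames for one scope from a single definition."""
    prefix = _PREFIX[scope]
    return ScopedFitArtifacts(
        weights=FitArtifact(f"{prefix}calibration_weights.npy"),
        geography=FitArtifact(f"{prefix}geography_assignment.npz"),
        run_config=FitArtifact(f"{prefix}unified_run_config.json"),
        diagnostics=FitArtifact(f"{prefix}unified_diagnostics.csv"),
        calibration_log=FitArtifact(f"{prefix}calibration_log.csv"),
    )
```

With a single lookup like this, orchestration code asks for `fit_artifacts_for_scope(FitScope.NATIONAL).weights.filename` instead of repeating the string literal, which is the pattern the `modal_app/local_area.py` changes in this commit adopt.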

docs/pipeline_map.yaml

Lines changed: 51 additions & 1 deletion
@@ -947,6 +947,8 @@ stages:
   label: run_calibration()
   description: 'Fit phase: initialize weights, optimize sparse calibration weights, and emit artifacts plus diagnostics'
   node_ids:
+  - fit_spec_regional
+  - fit_artifacts_regional
   - init_weights
   - create_model
   - fit_model
@@ -964,6 +966,14 @@ stages:
   label: Modal GPU Container
   node_type: external
   description: T4 / A10 / A100 / H100 - 32GB RAM, 8 CPU
+- id: fit_spec_regional
+  label: FittedWeightsSpec regional
+  node_type: library
+  description: Regional Stage 3 fit hyperparameters and deterministic reuse identity
+- id: fit_artifacts_regional
+  label: ScopedFitArtifacts regional
+  node_type: library
+  description: Regional fitted-weight artifact filenames and remote result mapping
 - id: create_model
   label: Create SparseCalibrationWeights
   node_type: process
@@ -977,7 +987,7 @@ stages:
   node_type: artifact
   description: 'Shape: (n_records x n_clones) - most entries zero'
 - id: out_geo_s6
-  label: geography.npz
+  label: geography_assignment.npz
   node_type: artifact
   description: block_geoid, cd_geoid, county_fips, state_fips
 - id: out_diag
@@ -1000,6 +1010,21 @@ stages:
 - source: in_pkg_s6
   target: init_weights
   edge_type: data_flow
+- source: fit_spec_regional
+  target: fit_model
+  edge_type: uses_library
+- source: fit_artifacts_regional
+  target: out_weights
+  edge_type: documents
+- source: fit_artifacts_regional
+  target: out_geo_s6
+  edge_type: documents
+- source: fit_artifacts_regional
+  target: out_diag
+  edge_type: documents
+- source: fit_artifacts_regional
+  target: out_config_s6
+  edge_type: documents
 - source: init_weights
   target: create_model
   edge_type: data_flow
@@ -1047,6 +1072,8 @@ stages:
   label: run_calibration() national
   description: 'National fit phase: initialize national weights, optimize sparse calibration weights, and emit national artifacts plus diagnostics'
   node_ids:
+  - fit_spec_national
+  - fit_artifacts_national
   - init_weights
   - create_model_national
   - fit_model
@@ -1064,6 +1091,14 @@ stages:
   label: Modal GPU Container
   node_type: external
   description: T4 / A10 / A100 / H100 for national calibration
+- id: fit_spec_national
+  label: FittedWeightsSpec national
+  node_type: library
+  description: National Stage 3 fit hyperparameters and deterministic reuse identity
+- id: fit_artifacts_national
+  label: ScopedFitArtifacts national
+  node_type: library
+  description: National fitted-weight artifact filenames and remote result mapping
 - id: create_model_national
   label: Create National SparseCalibrationWeights
   node_type: process
@@ -1100,6 +1135,21 @@ stages:
 - source: in_pkg_national_s6
   target: init_weights
   edge_type: data_flow
+- source: fit_spec_national
+  target: fit_model
+  edge_type: uses_library
+- source: fit_artifacts_national
+  target: out_national_weights
+  edge_type: documents
+- source: fit_artifacts_national
+  target: out_national_geo_s6
+  edge_type: documents
+- source: fit_artifacts_national
+  target: out_national_diag
+  edge_type: documents
+- source: fit_artifacts_national
+  target: out_national_config_s6
+  edge_type: documents
 - source: init_weights
   target: create_model_national
   edge_type: data_flow
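Because these hunks add both nodes and edges that refer to them by id, a small consistency check can catch a mistyped `source` or `target`. The sketch below is hypothetical and not part of the repository: it assumes each node is a mapping with an `id` key and each edge a mapping with `source` and `target` keys, as the hunks above suggest, and it works on already-parsed data rather than reading `docs/pipeline_map.yaml` itself. The helper name `find_dangling_edges` is invented for the example.

```python
def find_dangling_edges(nodes, edges):
    """Return edges whose source or target is not a declared node id."""
    known_ids = {node["id"] for node in nodes}
    return [
        edge
        for edge in edges
        if edge["source"] not in known_ids or edge["target"] not in known_ids
    ]


# Sample data using ids from the national hunk; `out_national_diag` is
# deliberately left out of the node list to show a dangling edge.
nodes = [
    {"id": "fit_spec_national"},
    {"id": "fit_artifacts_national"},
    {"id": "fit_model"},
    {"id": "out_national_weights"},
]
edges = [
    {"source": "fit_spec_national", "target": "fit_model", "edge_type": "uses_library"},
    {"source": "fit_artifacts_national", "target": "out_national_weights", "edge_type": "documents"},
    {"source": "fit_artifacts_national", "target": "out_national_diag", "edge_type": "documents"},
]

dangling = find_dangling_edges(nodes, edges)
for edge in dangling:
    print(f"dangling edge: {edge['source']} -> {edge['target']}")
```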

modal_app/local_area.py

Lines changed: 15 additions & 8 deletions
@@ -55,6 +55,10 @@
 from policyengine_us_data.build_outputs.worker_inputs import (  # noqa: E402
     WorkerCalibrationInputs,
 )
+from policyengine_us_data.fit_weights import (  # noqa: E402
+    FitScope,
+    fit_artifacts_for_scope,
+)
 from policyengine_us_data.pipeline_metadata import pipeline_node  # noqa: E402
 from policyengine_us_data.pipeline_schema import PipelineNode  # noqa: E402
 from policyengine_us_data.utils.run_context import (  # noqa: E402
@@ -1310,11 +1314,12 @@ def coordinate_publish(
     artifacts = (
         Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
     )
-    weights_path = artifacts / "calibration_weights.npy"
-    geography_path = artifacts / "geography_assignment.npz"
+    regional_fit_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL)
+    weights_path = artifacts / regional_fit_artifacts.weights.filename
+    geography_path = artifacts / regional_fit_artifacts.geography.filename
     db_path = artifacts / "policy_data.db"
     dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
-    config_json_path = artifacts / "unified_run_config.json"
+    config_json_path = artifacts / regional_fit_artifacts.run_config.filename
     calibration_package_path = artifacts / "calibration_package.pkl"
 
     required = {
@@ -1609,11 +1614,13 @@ def coordinate_national_publish(
     artifacts = (
         Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
     )
-    weights_path = artifacts / "national_calibration_weights.npy"
-    geography_path = artifacts / "national_geography_assignment.npz"
+    regional_fit_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL)
+    national_fit_artifacts = fit_artifacts_for_scope(FitScope.NATIONAL)
+    weights_path = artifacts / national_fit_artifacts.weights.filename
+    geography_path = artifacts / national_fit_artifacts.geography.filename
     db_path = artifacts / "policy_data.db"
     dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
-    config_json_path = artifacts / "national_unified_run_config.json"
+    config_json_path = artifacts / national_fit_artifacts.run_config.filename
 
     required = {
         "weights": weights_path,
@@ -1641,8 +1648,8 @@ def coordinate_national_publish(
         config_json_path,
         artifacts,
         filename_remap={
-            "calibration_weights.npy": "national_calibration_weights.npy",
-            "geography_assignment.npz": "national_geography_assignment.npz",
+            regional_fit_artifacts.weights.filename: national_fit_artifacts.weights.filename,
+            regional_fit_artifacts.geography.filename: national_fit_artifacts.geography.filename,
         },
     )
     fingerprint_inputs = _build_publishing_input_bundle(
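The last hunk builds `filename_remap` by pairing each regional filename with its national counterpart, which is why `coordinate_national_publish` now looks up both scopes. A rough sketch of that pairing follows. The `build_filename_remap` helper is hypothetical and does not exist in the commit, and `SimpleNamespace` objects stand in for the real `ScopedFitArtifacts` values so the example runs on its own.

```python
from types import SimpleNamespace


def _artifact(filename):
    # Stand-in for the real artifact objects, which expose `.filename` in the diff.
    return SimpleNamespace(filename=filename)


# Stand-ins for fit_artifacts_for_scope(FitScope.REGIONAL / FitScope.NATIONAL).
regional_fit_artifacts = SimpleNamespace(
    weights=_artifact("calibration_weights.npy"),
    geography=_artifact("geography_assignment.npz"),
)
national_fit_artifacts = SimpleNamespace(
    weights=_artifact("national_calibration_weights.npy"),
    geography=_artifact("national_geography_assignment.npz"),
)


def build_filename_remap(source, target, fields=("weights", "geography")):
    """Map each source-scope filename to the matching target-scope filename."""
    return {
        getattr(source, field).filename: getattr(target, field).filename
        for field in fields
    }


remap = build_filename_remap(regional_fit_artifacts, national_fit_artifacts)
print(remap)
```

The resulting dictionary has the same two entries as the string literals removed by this hunk, so the remap stays behavior-compatible while the names come from one place.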
