Skip to content

Commit a8eee33

Browse files
committed
Enforce Stage 2 contract before fitting weights
1 parent 4544761 commit a8eee33

10 files changed

Lines changed: 496 additions & 9 deletions

File tree

changelog.d/1116.changed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Require Stage 3 fitted-weight runs to verify the Stage 2 calibration package contract before fitting.

docs/engineering/stages/fit_weights.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights`
1212
`FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result
1313
bytes typed before they become files.
1414

15+
Normal pipeline runs must fit from a Stage 2 package that has a matching
16+
`calibration_package_contract.json` sidecar. `FittedWeightsInputBundle` reads
17+
that contract before GPU fitting starts, checks the contract-declared
18+
`calibration_package.pkl` checksum and size against the package on the pipeline
19+
volume, and records both the package checksum and contract checksum in the fit
20+
step parameters. Manual legacy package runs may proceed without the contract
21+
only through the explicit `allow_legacy_no_contract` fallback, which emits a warning and
22+
records that only the package checksum was available.
23+
1524
The current artifact names remain behavior-compatible:
1625

1726
- regional: `calibration_weights.npy`, `geography_assignment.npz`,

modal_app/pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,7 @@ def run_pipeline(
15341534
scope=FitScope.REGIONAL,
15351535
calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl",
15361536
)
1537+
fit_stage2_identity = regional_fit_input.stage2_identity_parameters()
15371538
fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths())
15381539
regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL)
15391540
national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL)
@@ -1542,11 +1543,12 @@ def run_pipeline(
15421543
regional_fit_parameters = regional_fit_spec.manifest_parameters(
15431544
gpu=gpu,
15441545
epochs=epochs,
1546+
extra=fit_stage2_identity,
15451547
)
15461548
national_fit_parameters = national_fit_spec.manifest_parameters(
15471549
gpu=national_gpu,
15481550
epochs=national_epochs,
1549-
extra={"skip_national": skip_national},
1551+
extra={**fit_stage2_identity, "skip_national": skip_national},
15501552
)
15511553
regional_fit_reuse = _step_reusable(
15521554
meta,
@@ -1587,6 +1589,9 @@ def run_pipeline(
15871589
step_start = time.time()
15881590

15891591
vol_path = f"{artifacts_dir_for_run(run_id)}/calibration_package.pkl"
1592+
vol_contract_path = str(
1593+
regional_fit_input.calibration_package_contract_path
1594+
)
15901595

15911596
# Spawn regional fit
15921597
regional_func = PACKAGE_GPU_FUNCTIONS[gpu]
@@ -1595,6 +1600,8 @@ def run_pipeline(
15951600
branch=branch,
15961601
epochs=epochs,
15971602
volume_package_path=vol_path,
1603+
volume_package_contract_path=vol_contract_path,
1604+
fit_scope=FitScope.REGIONAL.value,
15981605
**regional_fit_spec.runtime_kwargs(),
15991606
)
16001607
print(f" → regional fit fc: {regional_handle.object_id}")
@@ -1623,6 +1630,8 @@ def run_pipeline(
16231630
branch=branch,
16241631
epochs=national_epochs,
16251632
volume_package_path=vol_path,
1633+
volume_package_contract_path=vol_contract_path,
1634+
fit_scope=FitScope.NATIONAL.value,
16261635
**national_fit_spec.runtime_kwargs(),
16271636
)
16281637
print(f" → national fit fc: {national_handle.object_id}")

modal_app/remote_calibration_runner.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313

1414
from modal_app.images import gpu_image as image # noqa: E402
1515
from policyengine_us_data.calibration_package.specs import ( # noqa: E402
16+
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
1617
calibration_package_artifact_paths,
1718
stage2_build_context_for_run,
1819
)
1920
from policyengine_us_data.fit_weights import ( # noqa: E402
2021
FitResultBytes,
2122
FitScope,
23+
FittedWeightsInputBundle,
2224
NATIONAL_FIT_LAMBDA_L0,
2325
fit_artifacts_for_scope,
2426
)
@@ -288,6 +290,9 @@ def _fit_from_package_impl(
288290
branch: str,
289291
epochs: int,
290292
volume_package_path: str = None,
293+
volume_package_contract_path: str = None,
294+
allow_legacy_no_contract: bool = False,
295+
fit_scope: str = FitScope.REGIONAL.value,
291296
target_config: str = None,
292297
beta: float = None,
293298
lambda_l0: float = None,
@@ -300,6 +305,21 @@ def _fit_from_package_impl(
300305
raise ValueError("volume_package_path is required")
301306

302307
_setup_repo()
308+
input_bundle = FittedWeightsInputBundle(
309+
scope=fit_scope,
310+
calibration_package_path=Path(volume_package_path),
311+
calibration_package_contract_path=(
312+
Path(volume_package_contract_path) if volume_package_contract_path else None
313+
),
314+
allow_legacy_no_contract=allow_legacy_no_contract,
315+
)
316+
stage2_identity = input_bundle.stage2_identity()
317+
if stage2_identity.stage2_contract_mode == "stage2_contract":
318+
print(
319+
"Validated Stage 2 calibration package contract "
320+
f"{stage2_identity.calibration_package_contract_fingerprint}",
321+
flush=True,
322+
)
303323

304324
pkg_path = "/root/calibration_package.pkl"
305325
import shutil
@@ -816,11 +836,17 @@ def fit_from_package_t4(
816836
learning_rate: float = None,
817837
log_freq: int = None,
818838
volume_package_path: str = None,
839+
volume_package_contract_path: str = None,
840+
allow_legacy_no_contract: bool = False,
841+
fit_scope: str = FitScope.REGIONAL.value,
819842
) -> dict:
820843
return _fit_from_package_impl(
821844
branch,
822845
epochs,
823846
volume_package_path=volume_package_path,
847+
volume_package_contract_path=volume_package_contract_path,
848+
allow_legacy_no_contract=allow_legacy_no_contract,
849+
fit_scope=fit_scope,
824850
target_config=target_config,
825851
beta=beta,
826852
lambda_l0=lambda_l0,
@@ -848,11 +874,17 @@ def fit_from_package_a10(
848874
learning_rate: float = None,
849875
log_freq: int = None,
850876
volume_package_path: str = None,
877+
volume_package_contract_path: str = None,
878+
allow_legacy_no_contract: bool = False,
879+
fit_scope: str = FitScope.REGIONAL.value,
851880
) -> dict:
852881
return _fit_from_package_impl(
853882
branch,
854883
epochs,
855884
volume_package_path=volume_package_path,
885+
volume_package_contract_path=volume_package_contract_path,
886+
allow_legacy_no_contract=allow_legacy_no_contract,
887+
fit_scope=fit_scope,
856888
target_config=target_config,
857889
beta=beta,
858890
lambda_l0=lambda_l0,
@@ -880,11 +912,17 @@ def fit_from_package_a100_40(
880912
learning_rate: float = None,
881913
log_freq: int = None,
882914
volume_package_path: str = None,
915+
volume_package_contract_path: str = None,
916+
allow_legacy_no_contract: bool = False,
917+
fit_scope: str = FitScope.REGIONAL.value,
883918
) -> dict:
884919
return _fit_from_package_impl(
885920
branch,
886921
epochs,
887922
volume_package_path=volume_package_path,
923+
volume_package_contract_path=volume_package_contract_path,
924+
allow_legacy_no_contract=allow_legacy_no_contract,
925+
fit_scope=fit_scope,
888926
target_config=target_config,
889927
beta=beta,
890928
lambda_l0=lambda_l0,
@@ -912,11 +950,17 @@ def fit_from_package_a100_80(
912950
learning_rate: float = None,
913951
log_freq: int = None,
914952
volume_package_path: str = None,
953+
volume_package_contract_path: str = None,
954+
allow_legacy_no_contract: bool = False,
955+
fit_scope: str = FitScope.REGIONAL.value,
915956
) -> dict:
916957
return _fit_from_package_impl(
917958
branch,
918959
epochs,
919960
volume_package_path=volume_package_path,
961+
volume_package_contract_path=volume_package_contract_path,
962+
allow_legacy_no_contract=allow_legacy_no_contract,
963+
fit_scope=fit_scope,
920964
target_config=target_config,
921965
beta=beta,
922966
lambda_l0=lambda_l0,
@@ -944,11 +988,17 @@ def fit_from_package_h100(
944988
learning_rate: float = None,
945989
log_freq: int = None,
946990
volume_package_path: str = None,
991+
volume_package_contract_path: str = None,
992+
allow_legacy_no_contract: bool = False,
993+
fit_scope: str = FitScope.REGIONAL.value,
947994
) -> dict:
948995
return _fit_from_package_impl(
949996
branch,
950997
epochs,
951998
volume_package_path=volume_package_path,
999+
volume_package_contract_path=volume_package_contract_path,
1000+
allow_legacy_no_contract=allow_legacy_no_contract,
1001+
fit_scope=fit_scope,
9521002
target_config=target_config,
9531003
beta=beta,
9541004
lambda_l0=lambda_l0,
@@ -1008,12 +1058,23 @@ def main(
10081058

10091059
if package_path:
10101060
vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
1061+
local_contract_path = Path(package_path).with_name(
1062+
CALIBRATION_PACKAGE_CONTRACT_FILENAME
1063+
)
1064+
vol_contract_path = (
1065+
f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
1066+
if local_contract_path.exists()
1067+
else None
1068+
)
10111069
print(f"Reading package from {package_path}...", flush=True)
10121070
import json as _json
10131071
import pickle as _pkl
10141072

10151073
with open(package_path, "rb") as f:
10161074
package_bytes = f.read()
1075+
contract_bytes = (
1076+
local_contract_path.read_bytes() if local_contract_path.exists() else None
1077+
)
10171078
size = len(package_bytes)
10181079
pkg_meta = _pkl.loads(package_bytes).get("metadata", {})
10191080
sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode()
@@ -1032,6 +1093,11 @@ def main(
10321093
BytesIO(sidecar_bytes),
10331094
"artifacts/calibration_package_meta.json",
10341095
)
1096+
if contract_bytes is not None:
1097+
batch.put_file(
1098+
BytesIO(contract_bytes),
1099+
f"artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}",
1100+
)
10351101
pipeline_vol.commit()
10361102
del package_bytes
10371103
print("Upload complete.", flush=True)
@@ -1047,6 +1113,9 @@ def main(
10471113
learning_rate=learning_rate,
10481114
log_freq=log_freq,
10491115
volume_package_path=vol_path,
1116+
volume_package_contract_path=vol_contract_path,
1117+
allow_legacy_no_contract=True,
1118+
fit_scope=scope.value,
10501119
)
10511120
elif full_pipeline:
10521121
print(
@@ -1080,6 +1149,9 @@ def main(
10801149
)
10811150
else:
10821151
vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
1152+
vol_contract_path = (
1153+
f"{PIPELINE_MOUNT}/artifacts/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
1154+
)
10831155
vol_info = check_volume_package.remote()
10841156
if not vol_info["exists"]:
10851157
raise SystemExit(
@@ -1134,6 +1206,9 @@ def main(
11341206
learning_rate=learning_rate,
11351207
log_freq=log_freq,
11361208
volume_package_path=vol_path,
1209+
volume_package_contract_path=vol_contract_path,
1210+
allow_legacy_no_contract=True,
1211+
fit_scope=scope.value,
11371212
)
11381213

11391214
with open(output, "wb") as f:

policyengine_us_data/fit_weights/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from policyengine_us_data.fit_weights.bundles import (
1111
FitResultBytes,
1212
FitWeightsBuildContext,
13+
FittedWeightsInputContractError,
1314
FittedWeightsInputBundle,
15+
FittedWeightsInputIdentity,
1416
FittedWeightsOutputBundle,
1517
MissingFitWeightsOutputError,
1618
)
@@ -45,7 +47,9 @@
4547
"FitResultBytes",
4648
"FitScope",
4749
"FitWeightsBuildContext",
50+
"FittedWeightsInputContractError",
4851
"FittedWeightsInputBundle",
52+
"FittedWeightsInputIdentity",
4953
"FittedWeightsOutputBundle",
5054
"FittedWeightsSpec",
5155
"MissingFitWeightsOutputError",

0 commit comments

Comments
 (0)