Commit 77b13d3

Authored by baogorek, claude, and MaxGhenis
Fix h5 files by saving calibration geography artifact, and model fit resume function (#708)
* Save geography artifacts and add calibration resume/checkpoint support

  Calibration now persists geography_assignment.npz alongside weights so that downstream publish and worker steps use the exact same geography instead of regenerating it randomly. Adds --resume-from and --checkpoint-output flags to unified_calibration for continuing fits from a saved checkpoint or warm-starting from weights. Also gitignores *.csv.gz to prevent accidental commits of cached ORG data.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Run ruff format

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add changelog fragment for PR 708

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Stub l0 module in test so patch works without l0-python installed

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add self-employment and SSN card type count targets to calibration config

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Distinguish ITIN holders from SSN holders in CPS data

  Fix calibration crash on string constraint variables (ssn_card_type) by falling back from float32 cast when values are non-numeric.

  Impute ITIN status for undocumented (code-0) persons: select tax units with code-0 earners via weighted random sampling targeting 4.4M ITIN returns (IRS NTA), then mark all code-0 members of those units.

  Updates has_tin = (ssn_card_type != 0) | has_itin_number so ITIN holders correctly qualify for ODC ($500 credit).

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix PR 708 checkpoint and ID regressions

* Handle string ID fields in PUF cloning

* Fold taxpayer ID imputation into calibration resume PR

* Fix PUF subsample logging format

* fixes

* Expand PR 708 changelog summary

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Max Ghenis <mghenis@gmail.com>
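The string-constraint fix described in the message above (falling back from a float32 cast when values are non-numeric) could look roughly like the following. This is a minimal sketch and not the code from this commit: the helper name `coerce_constraint_values` and the example labels are hypothetical.

```python
import numpy as np


def coerce_constraint_values(values):
    """Cast to float32 when possible; otherwise keep the values as they are.

    Hypothetical helper, for illustration only.
    """
    arr = np.asarray(values)
    try:
        return arr.astype(np.float32)
    except (ValueError, TypeError):
        # Non-numeric values (for example string category labels) cannot be
        # cast, so they are returned unchanged for string comparison later.
        return arr


numeric = coerce_constraint_values([0, 1, 2])
labels = coerce_constraint_values(["CITIZEN", "NONE"])  # made-up labels
```

Numeric inputs come back as float32, while string-valued inputs keep their string dtype instead of raising.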
1 parent fd8cb99 commit 77b13d3

26 files changed

Lines changed: 2209 additions & 208 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 **/*.h5
 **/*.npy
 **/*.csv
+**/*.csv.gz
 **/_build
 **/*.pkl
 **/*.db

changelog.d/708.added

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Save calibration geography as a pipeline artifact, add ``--resume-from`` and checkpoint support for long-running calibration fits, and fix resume/artifact handling in the remote calibration pipeline. This also adds conservative CPS taxpayer-ID outputs (``has_tin``, ``has_valid_ssn``, and a temporary ``has_itin`` compatibility alias), plus string-valued constraint handling needed for ID-target calibration.
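The taxpayer-ID logic summarized here follows the rule from the commit message, `has_tin = (ssn_card_type != 0) | has_itin_number`, with ITIN status given to every code-0 member of a selected tax unit. The sketch below illustrates only that rule. It assumes integer card-type codes, uses made-up data, and takes the unit selection as given, so the weighted sampling step is not reproduced.

```python
import numpy as np

# Illustrative person-level data (made up for this sketch).
ssn_card_type = np.array([0, 0, 1, 0, 2])       # 0 = "code-0" person
person_tax_unit = np.array([0, 0, 0, 1, 2])     # index of each person's tax unit
unit_selected = np.array([True, False, False])  # units assumed chosen as ITIN filers

# Mark all code-0 members of the selected units as ITIN holders.
is_code_0 = ssn_card_type == 0
has_itin_number = is_code_0 & unit_selected[person_tax_unit]

# A person has a TIN if they have any SSN card type or an imputed ITIN.
has_tin = (ssn_card_type != 0) | has_itin_number
print(has_tin.tolist())  # [True, True, True, False, True]
```

The fourth person is code-0 in an unselected unit, so they are the only one left without a TIN.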

docs/calibration.md

Lines changed: 30 additions & 0 deletions
@@ -23,6 +23,12 @@ python -m policyengine_us_data.calibration.unified_calibration \
   --package-path storage/calibration/calibration_package.pkl \
   --epochs 500 --device cuda
 
+# Resume a previous fit for 500 more epochs:
+python -m policyengine_us_data.calibration.unified_calibration \
+  --package-path storage/calibration/calibration_package.pkl \
+  --resume-from storage/calibration/calibration_weights.npy \
+  --epochs 500 --device cuda
+
 # Full pipeline with PUF (build + fit in one shot):
 make calibrate
 ```
@@ -88,6 +94,30 @@ python -m policyengine_us_data.calibration.unified_calibration \
 You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix
 build only happens once.
 
+Every fit now also writes a checkpoint next to the weights output
+(`calibration_weights.checkpoint.pt` by default). To continue the same fit,
+pass `--resume-from` with the weights file or checkpoint path. If a sibling
+checkpoint exists next to the weights file, it is used automatically so the
+L0 gate state is restored as well.
+
+```bash
+python -m policyengine_us_data.calibration.unified_calibration \
+  --package-path storage/calibration/calibration_package.pkl \
+  --epochs 2000 \
+  --beta 0.65 \
+  --lambda-l0 1e-4 \
+  --lambda-l2 1e-12 \
+  --log-freq 500 \
+  --target-config policyengine_us_data/calibration/target_config.yaml \
+  --device cpu \
+  --output policyengine_us_data/storage/calibration/national/weights.npy \
+  --resume-from policyengine_us_data/storage/calibration/national/weights.npy
+```
+
+When `--resume-from` points to a checkpoint, `--epochs` means additional epochs
+to run beyond the saved checkpoint epoch count. If only a `.npy` weights file
+exists, the run warm-starts from those weights.
+
 ### 2. Full pipeline with PUF
 
 Adding `--puf-dataset` doubles the record count (~24K base records x 430 clones = ~10.3M columns) by
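
The resume rules documented in this diff (prefer a sibling checkpoint, otherwise warm-start from the `.npy` weights, and treat `--epochs` as additional epochs when a checkpoint is used) can be sketched as below. This is an illustration under assumptions, not the implementation in `unified_calibration`: the helpers `resolve_resume` and `final_epoch` are hypothetical, and the sibling name is derived from the documented default `calibration_weights.checkpoint.pt`.

```python
import tempfile
from pathlib import Path


def resolve_resume(resume_from):
    """Return (mode, path), where mode is "checkpoint" or "warm_start".

    Hypothetical helper, for illustration only.
    """
    path = Path(resume_from)
    if path.suffix == ".pt":
        return "checkpoint", path
    # weights.npy -> weights.checkpoint.pt, the assumed sibling naming.
    sibling = path.with_suffix(".checkpoint.pt")
    if sibling.exists():
        return "checkpoint", sibling
    return "warm_start", path


def final_epoch(mode, saved_epochs, epochs):
    # With a checkpoint, --epochs counts epochs beyond the saved epoch count.
    return saved_epochs + epochs if mode == "checkpoint" else epochs


with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "calibration_weights.npy"
    weights.touch()
    mode_without, _ = resolve_resume(weights)   # no checkpoint on disk yet
    (Path(tmp) / "calibration_weights.checkpoint.pt").touch()
    mode_with, used = resolve_resume(weights)   # sibling checkpoint found
```

Under these assumptions, resuming a 500-epoch checkpoint with `--epochs 500` ends at epoch 1000, whereas a warm start from weights alone runs 500 epochs from the loaded weights.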

modal_app/local_area.py

Lines changed: 7 additions & 0 deletions
@@ -334,6 +334,8 @@ def build_areas_worker(
         "--output-dir",
         str(output_dir),
     ]
+    if "geography" in calibration_inputs:
+        worker_cmd.extend(["--geography-path", calibration_inputs["geography"]])
     if "n_clones" in calibration_inputs:
         worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])])
     if "seed" in calibration_inputs:
@@ -659,6 +661,7 @@ def coordinate_publish(
         Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
     )
     weights_path = artifacts / "calibration_weights.npy"
+    geography_path = artifacts / "geography_assignment.npz"
     db_path = artifacts / "policy_data.db"
     dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
     config_json_path = artifacts / "unified_run_config.json"
@@ -678,6 +681,7 @@ def coordinate_publish(
 
     calibration_inputs = {
         "weights": str(weights_path),
+        "geography": str(geography_path),
         "dataset": str(dataset_path),
         "database": str(db_path),
         "n_clones": n_clones,
@@ -943,6 +947,7 @@ def coordinate_national_publish(
         Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
     )
     weights_path = artifacts / "national_calibration_weights.npy"
+    geography_path = artifacts / "national_geography_assignment.npz"
     db_path = artifacts / "policy_data.db"
     dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
     config_json_path = artifacts / "national_unified_run_config.json"
@@ -962,6 +967,7 @@ def coordinate_national_publish(
 
     calibration_inputs = {
         "weights": str(weights_path),
+        "geography": str(geography_path),
         "dataset": str(dataset_path),
         "database": str(db_path),
         "n_clones": n_clones,
@@ -972,6 +978,7 @@ def coordinate_national_publish(
         artifacts,
         filename_remap={
             "calibration_weights.npy": "national_calibration_weights.npy",
+            "geography_assignment.npz": "national_geography_assignment.npz",
         },
     )
     run_dir = staging_dir / run_id

modal_app/pipeline.py

Lines changed: 10 additions & 0 deletions
@@ -832,6 +832,11 @@ def run_pipeline(
                 BytesIO(regional_result["weights"]),
                 f"{artifacts_rel}/calibration_weights.npy",
             )
+            if regional_result.get("geography"):
+                batch.put_file(
+                    BytesIO(regional_result["geography"]),
+                    f"{artifacts_rel}/geography_assignment.npz",
+                )
             if regional_result.get("config"):
                 batch.put_file(
                     BytesIO(regional_result["config"]),
@@ -856,6 +861,11 @@ def run_pipeline(
                 BytesIO(national_result["weights"]),
                 f"{artifacts_rel}/national_calibration_weights.npy",
             )
+            if national_result.get("geography"):
+                batch.put_file(
+                    BytesIO(national_result["geography"]),
+                    f"{artifacts_rel}/national_geography_assignment.npz",
+                )
             if national_result.get("config"):
                 batch.put_file(
                     BytesIO(national_result["config"]),
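
Because the pipeline passes the geography artifact around as raw bytes wrapped in `BytesIO`, an `.npz` payload has to survive a round trip through memory. The sketch below shows one way that can work with NumPy. The array names `block_geoid` and `cd_geoid` and their values are hypothetical, since the real contents of `geography_assignment.npz` are not shown in this diff.

```python
from io import BytesIO

import numpy as np

# Hypothetical per-record geography assignment (made-up field names and values).
block_geoid = np.array(["010010201001000", "060372073021000"])
cd_geoid = np.array([101, 634])

# Serialize to an in-memory .npz archive and extract the raw bytes,
# analogous to a result dict carrying a "geography" bytes payload.
buffer = BytesIO()
np.savez(buffer, block_geoid=block_geoid, cd_geoid=cd_geoid)
payload = buffer.getvalue()

# A downstream step can rebuild the exact same arrays from the bytes.
restored = np.load(BytesIO(payload))
same_blocks = restored["block_geoid"].tolist() == block_geoid.tolist()
same_cds = restored["cd_geoid"].tolist() == cd_geoid.tolist()
```

Loading the saved arrays rather than regenerating them is what lets publish and worker steps reuse the geography chosen during calibration.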
