Skip to content

Commit b3c88ec

Browse files
authored
Use national calibration preset in pipeline (#947)
1 parent 40419da commit b3c88ec

3 files changed

Lines changed: 13 additions & 5 deletions

File tree

changelog.d/945.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use the national calibration preset L0 penalty in the Modal pipeline.

modal_app/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@
109109
or "policyengine-us-data-pipeline"
110110
)
111111

112+
NATIONAL_FIT_LAMBDA_L0 = 1e-4
113+
112114
hf_secret = modal.Secret.from_name("huggingface-token")
113115
gcp_secret = modal.Secret.from_name("gcp-credentials")
114116

@@ -1170,7 +1172,7 @@ def run_pipeline(
11701172
"epochs": national_epochs,
11711173
"target_config": "policyengine_us_data/calibration/target_config.yaml",
11721174
"beta": 0.65,
1173-
"lambda_l0": 2e-2,
1175+
"lambda_l0": NATIONAL_FIT_LAMBDA_L0,
11741176
"lambda_l2": 1e-12,
11751177
"log_freq": 100,
11761178
"skip_national": skip_national,
@@ -1257,7 +1259,7 @@ def run_pipeline(
12571259
volume_package_path=vol_path,
12581260
target_config=target_cfg,
12591261
beta=0.65,
1260-
lambda_l0=2e-2,
1262+
lambda_l0=NATIONAL_FIT_LAMBDA_L0,
12611263
lambda_l2=1e-12,
12621264
log_freq=100,
12631265
)

tests/unit/test_pipeline.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
modal = pytest.importorskip("modal")
1111

1212
from modal_app.pipeline import ( # noqa: E402
13+
NATIONAL_FIT_LAMBDA_L0,
1314
_build_diagnostics_upload_script,
1415
_calibration_package_parameters,
1516
_run_required_promotion_subprocess,
@@ -37,7 +38,7 @@ def test_calibration_package_parameters_track_matrix_mode():
3738
)
3839

3940
assert params["chunked_matrix"] is True
40-
assert params["workers"] is None
41+
assert "workers" not in params
4142
assert params["chunk_size"] == 10_000
4243
assert params["parallel_matrix"] is True
4344
assert params["num_matrix_workers"] == 25
@@ -57,9 +58,13 @@ def test_calibration_package_parameters_ignore_unused_matrix_options():
5758

5859
assert params["chunked_matrix"] is False
5960
assert params["workers"] == 50
60-
assert params["chunk_size"] is None
61+
assert "chunk_size" not in params
6162
assert params["parallel_matrix"] is False
62-
assert params["num_matrix_workers"] is None
63+
assert "num_matrix_workers" not in params
64+
65+
66+
def test_national_fit_lambda_matches_national_preset():
67+
assert NATIONAL_FIT_LAMBDA_L0 == pytest.approx(1e-4)
6368

6469

6570
class TestRunMetadata:

0 commit comments

Comments
 (0)