Skip to content

Commit 35321f9

Browse files
authored
Centralize long-run calibration weight access (#1035)
1 parent 880bb04 commit 35321f9

10 files changed

Lines changed: 174 additions & 54 deletions

File tree

.github/workflows/pr.yaml

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,20 @@ jobs:
165165
MODAL_PROXY_TOKEN_SECRET: ${{ secrets.MODAL_PROXY_TOKEN_SECRET }}
166166
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
167167
GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GOOGLE_APPLICATION_CREDENTIALS }}
168-
MODAL_ENVIRONMENT: staging-us-data-pr-${{ github.event.pull_request.number }}
169-
MODAL_APP_NAME: policyengine-us-data-pipeline
170-
MODAL_LOCAL_AREA_APP_NAME: policyengine-us-data-local-area
171-
MODAL_H5_TEST_HARNESS_APP_NAME: policyengine-us-data-h5-test-harness
168+
# Modal PR environments cannot reliably receive secrets with the CI token.
169+
# Deploy isolated PR apps and volumes into main, where required secrets
170+
# already exist, then stop/delete the PR resources in cleanup steps.
171+
MODAL_ENVIRONMENT: main
172+
MODAL_APP_NAME: us-data-pipeline-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
173+
MODAL_LOCAL_AREA_APP_NAME: us-data-local-area-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
174+
MODAL_H5_TEST_HARNESS_APP_NAME: us-data-h5-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
175+
US_DATA_PIPELINE_APP_NAME: us-data-pipeline-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
176+
US_DATA_MODAL_APP_NAME: us-data-pipeline-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
177+
US_DATA_LOCAL_AREA_APP_NAME: us-data-local-area-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
178+
US_DATA_H5_HARNESS_APP_NAME: us-data-h5-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
179+
US_DATA_PIPELINE_VOLUME_NAME: pipeline-artifacts-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
180+
US_DATA_STAGING_VOLUME_NAME: local-area-staging-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
181+
US_DATA_CHECKPOINT_VOLUME_NAME: data-build-checkpoints-pr-${{ github.event.pull_request.number }}-${{ github.run_id }}-${{ github.run_attempt }}
172182
steps:
173183
- uses: actions/checkout@v6
174184
- uses: actions/setup-python@v6
@@ -178,15 +188,11 @@ jobs:
178188
- run: uv sync --dev
179189
- name: Install integration test deps
180190
run: uv pip install modal pytest numpy pandas
181-
- name: Ensure PR Modal environment exists
182-
run: uv run python .github/scripts/ensure_modal_environment.py
183-
- name: Sync Modal secrets to PR environment
184-
run: uv run python .github/scripts/sync_modal_secrets.py
185-
- name: Deploy Modal pipeline app to PR staging
191+
- name: Deploy PR Modal pipeline app
186192
run: uv run modal deploy --env="${MODAL_ENVIRONMENT}" modal_app/pipeline.py
187-
- name: Deploy Modal local-area app to PR staging
193+
- name: Deploy PR Modal local-area app
188194
run: uv run modal deploy --env="${MODAL_ENVIRONMENT}" modal_app/local_area.py
189-
- name: Deploy Modal H5 test harness to PR staging
195+
- name: Deploy PR Modal H5 test harness
190196
run: uv run modal deploy --env="${MODAL_ENVIRONMENT}" modal_app/h5_test_harness.py
191197
- name: Run integration tests
192198
run: >
@@ -204,9 +210,32 @@ jobs:
204210
tests/integration/test_tiny_h5_pipeline.py
205211
tests/integration/test_modal_pipeline_e2e.py
206212
-v
207-
- name: Cleanup PR Modal environment
213+
- name: Stop PR Modal apps
208214
if: always()
209-
run: uv run python .github/scripts/delete_modal_environment.py
215+
run: |
216+
for app_name in \
217+
"${MODAL_H5_TEST_HARNESS_APP_NAME}" \
218+
"${MODAL_LOCAL_AREA_APP_NAME}" \
219+
"${MODAL_APP_NAME}"
220+
do
221+
yes | uv run modal app stop \
222+
--env="${MODAL_ENVIRONMENT}" \
223+
"${app_name}" || true
224+
done
225+
- name: Delete PR Modal volumes
226+
if: always()
227+
run: |
228+
for volume_name in \
229+
"${US_DATA_STAGING_VOLUME_NAME}" \
230+
"${US_DATA_PIPELINE_VOLUME_NAME}" \
231+
"${US_DATA_CHECKPOINT_VOLUME_NAME}"
232+
do
233+
uv run modal volume delete \
234+
--env="${MODAL_ENVIRONMENT}" \
235+
--allow-missing \
236+
--yes \
237+
"${volume_name}" || true
238+
done
210239
211240
smoke-test:
212241
runs-on: ubuntu-latest

changelog.d/1033.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Centralized long-run calibration weight access so that baseline diagnostics use PolicyEngine weighted operations, hardened the isolation of PR Modal integration runs, and added retries with backoff to the Census county lookup used by local-area H5 builds.

policyengine_us_data/calibration/block_assignment.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323

2424
import random
25+
import time
2526
import unicodedata
2627
from functools import lru_cache
2728
from io import StringIO
@@ -63,6 +64,35 @@ def get_tract_geoid_from_block(block_geoid: str) -> str:
6364

6465
# === County FIPS to Enum Mapping ===
6566

67+
COUNTY_FIPS_2020_URL = (
68+
"https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt"
69+
)
70+
COUNTY_FIPS_DOWNLOAD_ATTEMPTS = 5
71+
COUNTY_FIPS_RETRY_BACKOFF_SECONDS = 1.0
72+
73+
74+
def _county_fips_session() -> requests.Session:
    """Build a fresh HTTP session for the county FIPS download."""
    county_session = requests.Session()
    return county_session
76+
77+
78+
def _download_county_fips_2020(
    session: requests.Session | None = None,
) -> str:
    """
    Download the 2020 Census county FIPS reference file as text.

    The request is attempted up to ``COUNTY_FIPS_DOWNLOAD_ATTEMPTS`` times.
    After each failed attempt except the last, the function sleeps for
    ``COUNTY_FIPS_RETRY_BACKOFF_SECONDS * 2**attempt`` seconds before retrying.

    Args:
        session: Optional session to issue the request with. When omitted, a
            session is created via ``_county_fips_session()`` and closed
            before this function returns or raises. A caller-supplied
            session is never closed here.

    Returns:
        The response body decoded as UTF-8.

    Raises:
        requests.RequestException: The error from the final attempt, when
            every attempt fails.
        RuntimeError: When ``COUNTY_FIPS_DOWNLOAD_ATTEMPTS`` is not positive,
            so no request is ever made.
    """
    # Only a session created here is ours to release; closing a caller's
    # session would break their subsequent requests.
    owns_session = session is None
    if owns_session:
        session = _county_fips_session()

    try:
        for attempt in range(COUNTY_FIPS_DOWNLOAD_ATTEMPTS):
            try:
                response = session.get(COUNTY_FIPS_2020_URL, timeout=(10, 60))
                response.raise_for_status()
                # The body is fully read and decoded before the ``finally``
                # clause below closes the session.
                return response.content.decode("utf-8")
            except requests.RequestException:
                if attempt == COUNTY_FIPS_DOWNLOAD_ATTEMPTS - 1:
                    # Out of attempts: surface the underlying request error.
                    raise
                time.sleep(COUNTY_FIPS_RETRY_BACKOFF_SECONDS * (2**attempt))
    finally:
        if owns_session:
            session.close()

    # Reached only when the loop body never ran (non-positive attempt count).
    raise RuntimeError("Failed to download 2020 county FIPS data")
95+
6696

6797
@lru_cache(maxsize=1)
6898
def _build_county_fips_to_enum() -> Dict[str, str]:
@@ -72,11 +102,8 @@ def _build_county_fips_to_enum() -> Dict[str, str]:
72102
Downloads Census county FIPS file and matches to County enum names.
73103
Cached to avoid repeated downloads.
74104
"""
75-
url = "https://www2.census.gov/geo/docs/reference/codes2020/national_county2020.txt"
76-
response = requests.get(url, timeout=60)
77-
response.raise_for_status()
78105
df = pd.read_csv(
79-
StringIO(response.content.decode("utf-8")),
106+
StringIO(_download_county_fips_2020()),
80107
delimiter="|",
81108
dtype=str,
82109
usecols=["STATE", "STATEFP", "COUNTYFP", "COUNTYNAME"],

policyengine_us_data/datasets/cps/long_term/assess_publishable_horizon.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
aggregate_household_age_matrix,
2424
build_age_bins,
2525
build_household_age_matrix,
26+
household_calibration_weights,
2627
)
2728
from ssa_data import (
2829
get_long_term_target_source,
@@ -145,9 +146,9 @@ def assess_years(
145146
target_matrix = load_ssa_age_projections(start_year=start_year, end_year=end_year)
146147
n_ages = target_matrix.shape[0]
147148

148-
sim = Microsimulation(dataset=base_dataset_path)
149-
X, _, _ = build_household_age_matrix(sim, n_ages)
150-
del sim
149+
base_sim = Microsimulation(dataset=base_dataset_path)
150+
X, _, _ = build_household_age_matrix(base_sim, n_ages)
151+
del base_sim
151152
gc.collect()
152153

153154
aggregated_age_cache: dict[int, tuple[np.ndarray, np.ndarray]] = {}
@@ -158,8 +159,7 @@ def assess_years(
158159
year_idx = year - start_year
159160
sim = Microsimulation(dataset=base_dataset_path)
160161

161-
household_microseries = sim.calculate("household_id", map_to="household")
162-
baseline_weights = household_microseries.weights.values
162+
baseline_weights = household_calibration_weights(sim)
163163

164164
ss_values = None
165165
ss_target = None
@@ -294,7 +294,7 @@ def assess_years(
294294
best_case_match.group(2)
295295
)
296296
rows.append(row)
297-
del sim
297+
sim = None
298298
gc.collect()
299299
continue
300300

@@ -375,7 +375,7 @@ def assess_years(
375375

376376
rows.append(row)
377377

378-
del sim
378+
sim = None
379379
gc.collect()
380380

381381
return rows

policyengine_us_data/datasets/cps/long_term/check_calibrated_estimates_interactive.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from policyengine_core.reforms import Reform
34
from policyengine_us import Microsimulation
45

56
# H5_PATH = 'hf://policyengine/test/'
@@ -32,11 +33,8 @@
3233
## Population demographics, total
3334

3435
### Population count of 6 year olds
35-
person_weights = sim.calculate("age", map_to="person").weights
36-
person_ages = sim.calculate("age", map_to="person").values
37-
person_is_6 = person_ages == 6
38-
39-
total_age6_est = np.sum(person_is_6 * person_weights)
36+
person_ages = sim.calculate("age", map_to="person")
37+
total_age6_est = (person_ages == 6).sum()
4038

4139
### Single Year Age demographic projections - latest published is 2024:
4240
### "Mid Year" CSV from https://www.ssa.gov/oact/HistEst/Population/2024/Population2024.html
@@ -73,11 +71,8 @@
7371
## Population demographics, total
7472

7573
### Population count of 6 year olds
76-
person_weights = sim.calculate("age", map_to="person").weights
77-
person_ages = sim.calculate("age", map_to="person").values
78-
person_is_6 = person_ages == 6
79-
80-
total_age6_est = np.sum(person_is_6 * person_weights)
74+
person_ages = sim.calculate("age", map_to="person")
75+
total_age6_est = (person_ages == 6).sum()
8176

8277
### Single Year Age demographic projections - latest published is 2024:
8378
### "Mid Year" CSV from https://www.ssa.gov/oact/HistEst/Population/2024/Population2024.html
@@ -101,9 +96,6 @@
10196

10297
# Testing the H6 Reform ------------------------------------------------------
10398

104-
from policyengine_us import Microsimulation
105-
from policyengine_core.reforms import Reform
106-
10799

108100
def create_h6_reform():
109101
"""

policyengine_us_data/datasets/cps/long_term/evaluate_support_augmentation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import argparse
44
import json
5-
from pathlib import Path
65

76
import numpy as np
87
from policyengine_us import Microsimulation
@@ -17,6 +16,7 @@
1716
aggregate_household_age_matrix,
1817
build_age_bins,
1918
build_household_age_matrix,
19+
household_calibration_weights,
2020
)
2121
from ssa_data import (
2222
get_long_term_target_source,
@@ -60,8 +60,7 @@ def _evaluate_dataset(
6060
y_target = target_matrix[:, 0]
6161
age_bucket_size = 1
6262

63-
household_series = sim.calculate("household_id", period=year, map_to="household")
64-
baseline_weights = household_series.weights.values
63+
baseline_weights = household_calibration_weights(sim, period=year)
6564

6665
ss_values = None
6766
ss_target = None

policyengine_us_data/datasets/cps/long_term/projection_utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,26 @@ def _row_values(series):
3232
return np.asarray(series)
3333

3434

35+
def household_calibration_weights(sim, *, period=None) -> np.ndarray:
    """
    Extract the household-level weight vector used as the calibration
    decision vector.

    This helper exists for the long-run calibration optimizer, which solves
    directly for adjusted household weights and therefore needs the raw
    vector. Ordinary weighted totals should instead go through
    MicroSeries/MicroDataFrame methods such as ``sum()`` so that PolicyEngine
    stays responsible for mapping entities to weights.

    ``period`` is forwarded to ``sim.calculate`` only when it is not ``None``;
    otherwise the period choice is left to ``sim.calculate``.
    """
    calculate_kwargs = {"map_to": "household"}
    if period is not None:
        calculate_kwargs["period"] = period
    weighted_ids = sim.calculate("household_id", **calculate_kwargs)
    return np.asarray(weighted_ids.weights, dtype=float)
53+
54+
3555
def _person_level_values(sim, variable, *, period):
3656
try:
3757
series = sim.calculate(variable, period=period, map_to="person")
@@ -426,10 +446,7 @@ def calculate_year_statistics(
426446
income_tax_baseline_total = income_tax_hh.sum()
427447
income_tax_values = income_tax_hh.values
428448

429-
household_microseries = sim.calculate("household_id", map_to="household")
430-
# Explicit weight access is reserved for the household-level calibration
431-
# decision vector; ordinary aggregates should use MicroSeries methods.
432-
baseline_weights_actual = household_microseries.weights.values
449+
baseline_weights_actual = household_calibration_weights(sim)
433450

434451
ss_values = None
435452
ss_target = None

policyengine_us_data/datasets/cps/long_term/run_household_projection.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
build_age_bins,
8484
build_household_age_matrix,
8585
create_household_year_h5,
86+
household_calibration_weights,
8687
validate_projected_social_security_cap,
8788
)
8889
from tax_assumptions import (
@@ -1115,25 +1116,24 @@ def _print_support_augmentation_summary(augmentation_report: dict) -> None:
11151116
income_tax_values = income_tax_hh.values
11161117

11171118
household_microseries = sim.calculate("household_id", map_to="household")
1118-
# This is the calibrated household-weight decision vector. All ordinary
1119-
# baseline aggregates should continue to use MicroSeries methods directly.
1120-
baseline_weights = household_microseries.weights.values
1121-
household_ids_hh = household_microseries.values
1119+
baseline_weights = household_calibration_weights(sim)
1120+
household_ids_hh = np.asarray(household_microseries.array)
11221121

11231122
income_guard_constraints = {}
11241123
if year >= SUPPORT_AUGMENTATION_START_YEAR:
11251124
for group_name, components in INCOME_GUARD_GROUPS.items():
11261125
group_values = np.zeros(len(baseline_weights), dtype=float)
1126+
group_target = 0.0
11271127
included_components = []
11281128
for component in components:
11291129
if component not in sim.tax_benefit_system.variables:
11301130
continue
11311131
component_hh = sim.calculate(component, period=year, map_to="household")
11321132
group_values += np.asarray(component_hh.values, dtype=float)
1133+
group_target += float(component_hh.sum())
11331134
included_components.append(component)
11341135
if not included_components:
11351136
continue
1136-
group_target = float(np.sum(group_values * baseline_weights))
11371137
if abs(group_target) <= 1e-6:
11381138
continue
11391139
income_guard_constraints[f"income_guard_{group_name}"] = (
@@ -1164,7 +1164,7 @@ def _print_support_augmentation_summary(augmentation_report: dict) -> None:
11641164
ss_values = ss_hh.values
11651165
ss_target = load_ssa_benefit_projections(year)
11661166
if year in display_years:
1167-
ss_baseline = np.sum(ss_values * baseline_weights)
1167+
ss_baseline = ss_hh.sum()
11681168
print(
11691169
f" [DEBUG {year}] SS baseline: ${ss_baseline / 1e9:.1f}B, target: ${ss_target / 1e9:.1f}B"
11701170
)
@@ -1190,7 +1190,7 @@ def _print_support_augmentation_summary(augmentation_report: dict) -> None:
11901190
payroll_values = taxable_wages_hh.values + taxable_self_emp_hh.values
11911191
payroll_target = load_taxable_payroll_projections(year)
11921192
if year in display_years:
1193-
payroll_baseline = np.sum(payroll_values * baseline_weights)
1193+
payroll_baseline = taxable_wages_hh.sum() + taxable_self_emp_hh.sum()
11941194
print(f" [DEBUG {year}] Payroll cap: ${payroll_cap:,.0f}")
11951195
print(
11961196
f" [DEBUG {year}] Payroll baseline: ${payroll_baseline / 1e9:.1f}B, target: ${payroll_target / 1e9:.1f}B"
@@ -1231,7 +1231,7 @@ def _print_support_augmentation_summary(augmentation_report: dict) -> None:
12311231

12321232
# Debug output for key years
12331233
if year in display_years:
1234-
h6_impact_baseline = np.sum(h6_income_values * baseline_weights)
1234+
h6_impact_baseline = income_tax_reform_hh.sum() - income_tax_hh.sum()
12351235
print(
12361236
f" [DEBUG {year}] H6 baseline revenue: ${h6_impact_baseline / 1e9:.3f}B, target: ${h6_revenue_target / 1e9:.3f}B"
12371237
)
@@ -1260,8 +1260,8 @@ def _print_support_augmentation_summary(augmentation_report: dict) -> None:
12601260
hi_tob_target = load_hi_tob_projections(year)
12611261

12621262
if year in display_years:
1263-
oasdi_baseline = np.sum(oasdi_tob_values * baseline_weights)
1264-
hi_baseline = np.sum(hi_tob_values * baseline_weights)
1263+
oasdi_baseline = oasdi_tob_hh.sum()
1264+
hi_baseline = hi_tob_hh.sum()
12651265
print(
12661266
f" [DEBUG {year}] OASDI TOB baseline: ${oasdi_baseline / 1e9:.1f}B, target: ${oasdi_tob_target / 1e9:.1f}B"
12671267
)

0 commit comments

Comments
 (0)