Skip to content

Commit 7160996

Browse files
authored
Add BEA state wage calibration targets (#1034)
* Add BEA state wage calibration targets * Bump policyengine-us to 1.696.0 * Prioritize BEA wage controls in legacy reweighting * Format BEA wage loss weights * Relax Modal status seam timeout * Bump policyengine-us to 1.697.0 * Retry Modal PR secret sync
1 parent 544b30f commit 7160996

19 files changed

Lines changed: 673 additions & 16 deletions

.github/scripts/sync_modal_secrets.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
import os
22
import subprocess
3+
import time
4+
5+
6+
def create_secret_with_retry(args: list[str]) -> None:
7+
max_attempts = 5
8+
for attempt in range(1, max_attempts + 1):
9+
try:
10+
subprocess.run(args, check=True)
11+
return
12+
except subprocess.CalledProcessError:
13+
if attempt == max_attempts:
14+
raise
15+
delay = min(2**attempt, 10)
16+
print(
17+
"Modal secret creation failed; retrying "
18+
f"in {delay}s ({attempt}/{max_attempts})"
19+
)
20+
time.sleep(delay)
321

422

523
def main() -> None:
624
env_name = os.environ["MODAL_ENVIRONMENT"]
7-
subprocess.run(
25+
create_secret_with_retry(
826
[
927
"modal",
1028
"secret",
@@ -15,9 +33,8 @@ def main() -> None:
1533
"huggingface-token",
1634
f"HUGGING_FACE_TOKEN={os.environ['HUGGING_FACE_TOKEN']}",
1735
],
18-
check=True,
1936
)
20-
subprocess.run(
37+
create_secret_with_retry(
2138
[
2239
"modal",
2340
"secret",
@@ -31,7 +48,6 @@ def main() -> None:
3148
f"{os.environ['GOOGLE_APPLICATION_CREDENTIALS']}"
3249
),
3350
],
34-
check=True,
3551
)
3652

3753

AGENTS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ Manually sourced national or local-file calibration targets must be registered
2929
in every active target path before merging:
3030

3131
1. `policyengine_us_data/utils/loss.py` for the ECPS loss matrix.
32-
2. `policyengine_us_data/db/etl_national_targets.py` for `policy_data.db` and
33-
local H5 validation inputs.
32+
2. The appropriate `policyengine_us_data/db/etl_*.py` loader for
33+
`policy_data.db` and local H5 validation inputs. National targets usually
34+
belong in `etl_national_targets.py`; state or local targets should use a
35+
state/local ETL module and must still be added to this DB path.
3436
3. `policyengine_us_data/calibration/target_config.yaml` when the default
3537
calibration uses an `include:` list; otherwise the target can exist in
3638
`policy_data.db` but still be omitted from calibration.

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ database:
8484
python -m policyengine_us_data.db.create_database_tables
8585
python -m policyengine_us_data.db.create_initial_strata --year $(YEAR)
8686
python -m policyengine_us_data.db.etl_national_targets --year $(YEAR)
87+
python -m policyengine_us_data.db.etl_bea_state_wages --year $(YEAR)
8788
python -m policyengine_us_data.db.etl_age --year $(YEAR)
8889
python -m policyengine_us_data.db.etl_medicaid --year $(YEAR)
8990
python -m policyengine_us_data.db.etl_snap --year $(YEAR)

changelog.d/1033.added

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add BEA regional state wage targets to calibration.

docs/engineering/skills/calibration_targets.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ calibration targets.
88
New targets must be registered in both active target systems:
99

1010
- `policyengine_us_data/utils/loss.py` for the ECPS `build_loss_matrix()` path.
11-
- `policyengine_us_data/db/etl_national_targets.py` for `policy_data.db`, local
12-
H5 outputs, and validation inputs.
11+
- The appropriate `policyengine_us_data/db/etl_*.py` loader for
12+
`policy_data.db`, local H5 outputs, and validation inputs. National targets
13+
usually belong in `etl_national_targets.py`; state and local targets should
14+
use a state/local ETL module and must still be present in this DB path.
1315

1416
If the default calibration path uses `policyengine_us_data/calibration/target_config.yaml`
1517
with an `include:` list, also add the matching include rule there. A target can
@@ -19,10 +21,10 @@ from `target_config.yaml`.
1921
## Tests
2022

2123
Every target change should add or update tests that prove the target is wired
22-
through every active path. For manually sourced national targets, cover:
24+
through every active path. For manually sourced targets, cover:
2325

2426
- the ECPS loss matrix registration in `tests/unit/calibration/test_loss_targets.py`;
25-
- the DB ETL row in `tests/unit/test_etl_national_targets.py`;
27+
- the DB ETL row in the matching `tests/unit/test_etl_*.py` file;
2628
- the default calibration include rule in
2729
`tests/unit/calibration/test_target_config.py`;
2830
- any publication guard in `tests/unit/test_upload_completed_datasets.py` when

policyengine_us_data/calibration/target_config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ include:
5858
geo_level: state
5959
- variable: adjusted_gross_income
6060
geo_level: state
61+
# BEA regional wage totals. New macro targets like this must be wired into
62+
# both legacy utils/loss.py and the target DB ETL, or one publication path
63+
# will silently miss the constraint.
64+
- variable: employment_income_before_lsr
65+
geo_level: state
6166
- variable: rent
6267
geo_level: state
6368
- variable: spm_unit_count

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ABSOLUTE_ERROR_SCALE_TARGETS,
99
build_loss_matrix,
1010
get_target_error_normalisation,
11+
get_target_loss_weights,
1112
HardConcrete,
1213
print_reweighting_diagnostics,
1314
set_seeds,
@@ -506,7 +507,9 @@ def reweight(
506507
normalisation_factor = np.where(
507508
is_national, nation_normalisation_factor, state_normalisation_factor
508509
)
510+
target_loss_weights = get_target_loss_weights(target_names)
509511
normalisation_factor = torch.tensor(normalisation_factor, dtype=torch.float32)
512+
target_loss_weights = torch.tensor(target_loss_weights, dtype=torch.float32)
510513
targets_array = torch.tensor(targets_array, dtype=torch.float32)
511514
numerator_shift = torch.tensor(numerator_shift_np, dtype=torch.float32)
512515
error_denominator = torch.tensor(error_denominator_np, dtype=torch.float32)
@@ -525,6 +528,7 @@ def loss(weights):
525528
(estimate - targets_array + numerator_shift) / error_denominator
526529
) ** 2
527530
rel_error_normalized = inv_mean_normalisation * rel_error * normalisation_factor
531+
rel_error_normalized = rel_error_normalized * target_loss_weights
528532
if torch.isnan(rel_error_normalized).any():
529533
raise ValueError("Relative error contains NaNs")
530534
return rel_error_normalized.mean()

policyengine_us_data/db/create_field_valid_values.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def populate_field_valid_values(session: Session) -> None:
6969
# Static values for source field
7070
source_values = [
7171
("source", "Census ACS S0101", "survey"),
72+
("source", "BEA Regional SAINC4", "administrative"),
7273
("source", "IRS SOI", "administrative"),
7374
("source", "IRS EITC Central", "administrative"),
7475
("source", "CMS Marketplace", "administrative"),
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""ETL for BEA regional state wage calibration targets."""
2+
3+
import logging
4+
5+
import pandas as pd
6+
from sqlmodel import Session, create_engine
7+
8+
from policyengine_us_data.db.etl_national_targets import (
9+
BEA_NIPA_WAGES_AND_SALARIES_2024,
10+
_register_target_variable,
11+
_upsert_baseline_target,
12+
)
13+
from policyengine_us_data.storage import STORAGE_FOLDER
14+
from policyengine_us_data.utils.bea_regional import (
15+
BEA_STATE_WAGES_SOURCE,
16+
BEA_STATE_WAGES_SOURCE_URL,
17+
get_bea_state_wage_targets,
18+
)
19+
from policyengine_us_data.utils.db import etl_argparser, get_geographic_strata
20+
21+
logger = logging.getLogger(__name__)
22+
23+
TARGET_VARIABLE = "employment_income_before_lsr"
24+
25+
26+
def extract_bea_state_wage_targets(year: int) -> tuple[pd.DataFrame, int]:
27+
"""Extract BEA state wage targets scaled to the national NIPA total."""
28+
return get_bea_state_wage_targets(
29+
year,
30+
national_total=BEA_NIPA_WAGES_AND_SALARIES_2024,
31+
)
32+
33+
34+
def load_bea_state_wage_targets(
35+
targets: pd.DataFrame,
36+
*,
37+
target_year: int,
38+
source_year: int,
39+
) -> int:
40+
"""Load BEA state wage targets into state geographic strata."""
41+
if targets.empty:
42+
return 0
43+
44+
database_url = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}"
45+
engine = create_engine(database_url)
46+
loaded = 0
47+
48+
with Session(engine) as session:
49+
_register_target_variable(session, TARGET_VARIABLE)
50+
geo_strata = get_geographic_strata(session)
51+
state_strata = geo_strata.get("state", {})
52+
53+
for row in targets.itertuples(index=False):
54+
state_fips = int(row.state_fips)
55+
stratum_id = state_strata.get(state_fips)
56+
if stratum_id is None:
57+
logger.warning(
58+
"No geographic stratum found for state %s (FIPS %s), skipping",
59+
row.state_code,
60+
state_fips,
61+
)
62+
continue
63+
64+
_upsert_baseline_target(
65+
session,
66+
stratum_id=stratum_id,
67+
variable=TARGET_VARIABLE,
68+
period=target_year,
69+
value=float(row.employment_income_before_lsr),
70+
source=BEA_STATE_WAGES_SOURCE,
71+
notes=(
72+
"BEA SAINC4 line 50 wages and salaries by state, adjusted "
73+
"to a residence basis by allocating line 42's residence "
74+
"adjustment to wages in proportion to place-of-work "
75+
"net-compensation components, then scaled to the national "
76+
"BEA NIPA Table 2.1 wages and salaries target. "
77+
f"Source year: {source_year}; state: {row.state_code}; "
78+
f"raw residence-adjusted state wages: "
79+
f"${row.wages_and_salaries:,.0f}; "
80+
f"national scaling factor: {row.scale_factor:.8f}; "
81+
f"Source: {BEA_STATE_WAGES_SOURCE_URL}"
82+
),
83+
)
84+
loaded += 1
85+
86+
session.commit()
87+
88+
logger.info("Loaded %s BEA state wage targets", loaded)
89+
return loaded
90+
91+
92+
def main():
93+
logging.basicConfig(
94+
level=logging.INFO,
95+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
96+
)
97+
_, year = etl_argparser(
98+
"ETL for BEA regional state wage calibration targets",
99+
allow_year=True,
100+
)
101+
102+
targets, source_year = extract_bea_state_wage_targets(year)
103+
loaded = load_bea_state_wage_targets(
104+
targets,
105+
target_year=year,
106+
source_year=source_year,
107+
)
108+
109+
logger.info(
110+
"BEA State Wage Targets Summary:\n"
111+
" Source year: %s\n"
112+
" States loaded: %s\n"
113+
" Target total: $%.1fT",
114+
source_year,
115+
loaded,
116+
targets["employment_income_before_lsr"].sum() / 1e12,
117+
)
118+
119+
120+
if __name__ == "__main__":
121+
main()

policyengine_us_data/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"HardConcrete",
1414
"build_loss_matrix",
1515
"get_target_error_normalisation",
16+
"get_target_loss_weights",
1617
"print_reweighting_diagnostics",
1718
"set_seeds",
1819
]

0 commit comments

Comments
 (0)