Skip to content

Commit 630ee2a

Browse files
authored
Avoid cross-country imports in simulation worker
Set the simulation worker to avoid top-level country imports and pass Hugging Face credentials into Modal for private country data manifests.
1 parent 64acd89 commit 630ee2a

4 files changed

Lines changed: 23 additions & 24 deletions

File tree

.github/scripts/modal-sync-secrets.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ if [ -n "${GCP_CREDENTIALS_JSON:-}" ]; then
6262
--force || true
6363
fi
6464

65+
uv run modal secret create policyengine-data-credentials \
66+
"HUGGING_FACE_TOKEN=${HUGGING_FACE_TOKEN:-}" \
67+
--env="$MODAL_ENV" \
68+
--force || true
69+
6570
# Sync gateway auth config. The gateway runtime only needs issuer/audience and
6671
# the explicit requirement flag; client credentials stay on the GitHub side and
6772
# are only used to mint integration-test tokens.

.github/workflows/modal-deploy.reusable.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
5656
LOGFIRE_TOKEN: ${{ secrets.LOGFIRE_TOKEN }}
5757
GCP_CREDENTIALS_JSON: ${{ secrets.GCP_CREDENTIALS_JSON }}
58+
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
5859
GATEWAY_AUTH_ISSUER: ${{ secrets.GATEWAY_AUTH_ISSUER }}
5960
GATEWAY_AUTH_AUDIENCE: ${{ secrets.GATEWAY_AUTH_AUDIENCE }}
6061
GATEWAY_AUTH_CLIENT_ID: ${{ secrets.GATEWAY_AUTH_CLIENT_ID }}

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_app_name(us_version: str, uk_version: str) -> str:
3939
# Secrets
4040
# GCP credentials are shared across environments (always from main)
4141
gcp_secret = modal.Secret.from_name("gcp-credentials", environment_name="main")
42+
data_secret = modal.Secret.from_name("policyengine-data-credentials")
4243
# Logfire secret is environment-specific
4344
logfire_secret = modal.Secret.from_name("policyengine-logfire")
4445

@@ -80,7 +81,7 @@ def configure_logfire(service_name: str = "policyengine-simulation"):
8081
timeout=3600,
8182
retries=0,
8283
max_containers=100,
83-
secrets=[gcp_secret, logfire_secret],
84+
secrets=[gcp_secret, data_secret, logfire_secret],
8485
)
8586
def run_simulation(params: dict) -> dict:
8687
"""
@@ -118,7 +119,7 @@ def run_simulation(params: dict) -> dict:
118119
timeout=3600,
119120
retries=0,
120121
max_containers=100,
121-
secrets=[gcp_secret, logfire_secret],
122+
secrets=[gcp_secret, data_secret, logfire_secret],
122123
)
123124
def run_budget_window_batch(params: dict) -> dict:
124125
"""Execute a multi-year budget-window batch orchestration."""

projects/policyengine-api-simulation/src/modal/simulation.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010
import logging
1111
import os
1212
import tempfile
13+
import importlib
1314
from typing import Any, Iterator
1415

16+
# policyengine.core is imported for every simulation. Without this guard,
17+
# importing the package pulls both country modules into the process; a US run
18+
# can then fail before it starts if UK private-data credentials are absent.
19+
os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1")
20+
1521
try:
1622
from src.modal.telemetry import split_internal_payload
1723
except ModuleNotFoundError:
@@ -236,13 +242,11 @@ def group_subset(entity: str):
236242

237243

238244
def _country_module(country: str):
239-
import policyengine as pe
240-
241245
country = country.lower()
242246
if country == "us":
243-
return pe.us
247+
return importlib.import_module("policyengine.tax_benefit_models.us")
244248
if country == "uk":
245-
return pe.uk
249+
return importlib.import_module("policyengine.tax_benefit_models.uk")
246250
raise ValueError(f"Unsupported country: {country}")
247251

248252

@@ -327,18 +331,10 @@ def _budget_result(country: str, baseline, reform) -> dict[str, float]:
327331

328332

329333
def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str, Any]]]:
330-
import policyengine as pe
331-
332-
if country == "us":
333-
baseline_poverty = pe.us.economic_impact_analysis(
334-
baseline, reform
335-
).baseline_poverty
336-
reform_poverty = pe.us.economic_impact_analysis(baseline, reform).reform_poverty
337-
else:
338-
baseline_poverty = pe.uk.economic_impact_analysis(
339-
baseline, reform
340-
).baseline_poverty
341-
reform_poverty = pe.uk.economic_impact_analysis(baseline, reform).reform_poverty
334+
country_module = _country_module(country)
335+
impact = country_module.economic_impact_analysis(baseline, reform)
336+
baseline_poverty = impact.baseline_poverty
337+
reform_poverty = impact.reform_poverty
342338

343339
return {
344340
"baseline": baseline_poverty.dataframe.to_dict("records"),
@@ -347,12 +343,8 @@ def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str,
347343

348344

349345
def _analysis_result(country: str, baseline, reform) -> dict[str, Any]:
350-
import policyengine as pe
351-
352-
if country == "us":
353-
analysis = pe.us.economic_impact_analysis(baseline, reform)
354-
else:
355-
analysis = pe.uk.economic_impact_analysis(baseline, reform)
346+
country_module = _country_module(country)
347+
analysis = country_module.economic_impact_analysis(baseline, reform)
356348

357349
return {
358350
"decile_impacts": analysis.decile_impacts.dataframe.to_dict("records"),

0 commit comments

Comments
 (0)